Initial commit

This commit is contained in:
Ria Bhatia
2017-12-04 13:32:57 -06:00
committed by Erik St. Martin
commit 0075e5b0f3
9056 changed files with 2523100 additions and 0 deletions

11
vendor/github.com/hyperhq/hypercli/pkg/README.md generated vendored Normal file
View File

@@ -0,0 +1,11 @@
pkg/ is a collection of utility packages used by the Docker project without being specific to its internals.
Utility packages are kept separate from the docker core codebase to keep it as small and concise as possible.
If some utilities grow larger and their APIs stabilize, they may be moved to their own repository under the
Docker organization, to facilitate re-use by other projects. However that is not the priority.
The directory `pkg` is named after the same directory in the camlistore project. Since Brad is a core
Go maintainer, we thought it made sense to copy his methods for organizing Go code :) Thanks Brad!
Because utility packages are small and neatly separated from the rest of the codebase, they are a good
place to start for aspiring maintainers and contributors. Get in touch if you want to help maintain them!

View File

@@ -0,0 +1,76 @@
// Package aaparser is a convenience package interacting with `apparmor_parser`.
package aaparser
import (
"fmt"
"os/exec"
"path/filepath"
"strconv"
"strings"
)
const (
binary = "apparmor_parser"
)
// GetVersion returns the major and minor version of apparmor_parser.
func GetVersion() (int, int, error) {
output, err := cmd("", "--version")
if err != nil {
return -1, -1, err
}
return parseVersion(string(output))
}
// LoadProfile runs `apparmor_parser -r -W` on a specified apparmor profile to
// replace and write it to disk.
func LoadProfile(profilePath string) error {
_, err := cmd(filepath.Dir(profilePath), "-r", "-W", filepath.Base(profilePath))
if err != nil {
return err
}
return nil
}
// cmd runs `apparmor_parser` with the passed arguments.
func cmd(dir string, arg ...string) (string, error) {
c := exec.Command(binary, arg...)
c.Dir = dir
output, err := c.CombinedOutput()
if err != nil {
return "", fmt.Errorf("running `%s %s` failed with output: %s\nerror: %v", c.Path, strings.Join(c.Args, " "), string(output), err)
}
return string(output), nil
}
// parseVersion takes the output from `apparmor_parser --version` and returns
// the major and minor version for `apparor_parser`.
func parseVersion(output string) (int, int, error) {
// output is in the form of the following:
// AppArmor parser version 2.9.1
// Copyright (C) 1999-2008 Novell Inc.
// Copyright 2009-2012 Canonical Ltd.
lines := strings.SplitN(output, "\n", 2)
words := strings.Split(lines[0], " ")
version := words[len(words)-1]
// split by major minor version
v := strings.Split(version, ".")
if len(v) < 2 {
return -1, -1, fmt.Errorf("parsing major minor version failed for output: `%s`", output)
}
majorVersion, err := strconv.Atoi(v[0])
if err != nil {
return -1, -1, err
}
minorVersion, err := strconv.Atoi(v[1])
if err != nil {
return -1, -1, err
}
return majorVersion, minorVersion, nil
}

View File

@@ -0,0 +1,65 @@
package aaparser
import (
"testing"
)
type versionExpected struct {
output string
major int
minor int
}
func TestParseVersion(t *testing.T) {
versions := []versionExpected{
{
output: `AppArmor parser version 2.10
Copyright (C) 1999-2008 Novell Inc.
Copyright 2009-2012 Canonical Ltd.
`,
major: 2,
minor: 10,
},
{
output: `AppArmor parser version 2.8
Copyright (C) 1999-2008 Novell Inc.
Copyright 2009-2012 Canonical Ltd.
`,
major: 2,
minor: 8,
},
{
output: `AppArmor parser version 2.20
Copyright (C) 1999-2008 Novell Inc.
Copyright 2009-2012 Canonical Ltd.
`,
major: 2,
minor: 20,
},
{
output: `AppArmor parser version 2.05
Copyright (C) 1999-2008 Novell Inc.
Copyright 2009-2012 Canonical Ltd.
`,
major: 2,
minor: 5,
},
}
for _, v := range versions {
major, minor, err := parseVersion(v.output)
if err != nil {
t.Fatalf("expected error to be nil for %#v, got: %v", v, err)
}
if major != v.major {
t.Fatalf("expected major version to be %d, was %d, for: %#v\n", v.major, major, v)
}
if minor != v.minor {
t.Fatalf("expected minor version to be %d, was %d, for: %#v\n", v.minor, minor, v)
}
}
}

View File

@@ -0,0 +1 @@
This code provides helper functions for dealing with archive files.

1049
vendor/github.com/hyperhq/hypercli/pkg/archive/archive.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,112 @@
// +build !windows
package archive
import (
"archive/tar"
"errors"
"os"
"path/filepath"
"syscall"
"github.com/hyperhq/hypercli/pkg/system"
)
// fixVolumePathPrefix does platform specific processing to ensure that if
// the path being passed in is not in a volume path format, convert it to one.
func fixVolumePathPrefix(srcPath string) string {
return srcPath
}
// getWalkRoot calculates the root path when performing a TarWithOptions.
// We use a separate function as this is platform specific. On Linux, we
// can't use filepath.Join(srcPath,include) because this will clean away
// a trailing "." or "/" which may be important.
func getWalkRoot(srcPath string, include string) string {
return srcPath + string(filepath.Separator) + include
}
// CanonicalTarNameForPath returns platform-specific filepath
// to canonical posix-style path for tar archival. p is relative
// path.
func CanonicalTarNameForPath(p string) (string, error) {
return p, nil // already unix-style
}
// chmodTarEntry is used to adjust the file permissions used in tar header based
// on the platform the archival is done.
func chmodTarEntry(perm os.FileMode) os.FileMode {
return perm // noop for unix as golang APIs provide perm bits correctly
}
func setHeaderForSpecialDevice(hdr *tar.Header, ta *tarAppender, name string, stat interface{}) (inode uint64, err error) {
s, ok := stat.(*syscall.Stat_t)
if !ok {
err = errors.New("cannot convert stat value to syscall.Stat_t")
return
}
inode = uint64(s.Ino)
// Currently go does not fill in the major/minors
if s.Mode&syscall.S_IFBLK != 0 ||
s.Mode&syscall.S_IFCHR != 0 {
hdr.Devmajor = int64(major(uint64(s.Rdev)))
hdr.Devminor = int64(minor(uint64(s.Rdev)))
}
return
}
func getFileUIDGID(stat interface{}) (int, int, error) {
s, ok := stat.(*syscall.Stat_t)
if !ok {
return -1, -1, errors.New("cannot convert stat value to syscall.Stat_t")
}
return int(s.Uid), int(s.Gid), nil
}
func major(device uint64) uint64 {
return (device >> 8) & 0xfff
}
func minor(device uint64) uint64 {
return (device & 0xff) | ((device >> 12) & 0xfff00)
}
// handleTarTypeBlockCharFifo is an OS-specific helper function used by
// createTarFile to handle the following types of header: Block; Char; Fifo
func handleTarTypeBlockCharFifo(hdr *tar.Header, path string) error {
mode := uint32(hdr.Mode & 07777)
switch hdr.Typeflag {
case tar.TypeBlock:
mode |= syscall.S_IFBLK
case tar.TypeChar:
mode |= syscall.S_IFCHR
case tar.TypeFifo:
mode |= syscall.S_IFIFO
}
if err := system.Mknod(path, mode, int(system.Mkdev(hdr.Devmajor, hdr.Devminor))); err != nil {
return err
}
return nil
}
func handleLChmod(hdr *tar.Header, path string, hdrInfo os.FileInfo) error {
if hdr.Typeflag == tar.TypeLink {
if fi, err := os.Lstat(hdr.Linkname); err == nil && (fi.Mode()&os.ModeSymlink == 0) {
if err := os.Chmod(path, hdrInfo.Mode()); err != nil {
return err
}
}
} else if hdr.Typeflag != tar.TypeSymlink {
if err := os.Chmod(path, hdrInfo.Mode()); err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,60 @@
// +build !windows
package archive
import (
"os"
"testing"
)
func TestCanonicalTarNameForPath(t *testing.T) {
cases := []struct{ in, expected string }{
{"foo", "foo"},
{"foo/bar", "foo/bar"},
{"foo/dir/", "foo/dir/"},
}
for _, v := range cases {
if out, err := CanonicalTarNameForPath(v.in); err != nil {
t.Fatalf("cannot get canonical name for path: %s: %v", v.in, err)
} else if out != v.expected {
t.Fatalf("wrong canonical tar name. expected:%s got:%s", v.expected, out)
}
}
}
func TestCanonicalTarName(t *testing.T) {
cases := []struct {
in string
isDir bool
expected string
}{
{"foo", false, "foo"},
{"foo", true, "foo/"},
{"foo/bar", false, "foo/bar"},
{"foo/bar", true, "foo/bar/"},
}
for _, v := range cases {
if out, err := canonicalTarName(v.in, v.isDir); err != nil {
t.Fatalf("cannot get canonical name for path: %s: %v", v.in, err)
} else if out != v.expected {
t.Fatalf("wrong canonical tar name. expected:%s got:%s", v.expected, out)
}
}
}
func TestChmodTarEntry(t *testing.T) {
cases := []struct {
in, expected os.FileMode
}{
{0000, 0000},
{0777, 0777},
{0644, 0644},
{0755, 0755},
{0444, 0444},
}
for _, v := range cases {
if out := chmodTarEntry(v.in); out != v.expected {
t.Fatalf("wrong chmod. expected:%v got:%v", v.expected, out)
}
}
}

View File

@@ -0,0 +1,70 @@
// +build windows
package archive
import (
"archive/tar"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/hyperhq/hypercli/pkg/longpath"
)
// fixVolumePathPrefix does platform specific processing to ensure that if
// the path being passed in is not in a volume path format, convert it to one.
func fixVolumePathPrefix(srcPath string) string {
return longpath.AddPrefix(srcPath)
}
// getWalkRoot calculates the root path when performing a TarWithOptions.
// We use a separate function as this is platform specific.
func getWalkRoot(srcPath string, include string) string {
return filepath.Join(srcPath, include)
}
// CanonicalTarNameForPath returns platform-specific filepath
// to canonical posix-style path for tar archival. p is relative
// path.
func CanonicalTarNameForPath(p string) (string, error) {
// windows: convert windows style relative path with backslashes
// into forward slashes. Since windows does not allow '/' or '\'
// in file names, it is mostly safe to replace however we must
// check just in case
if strings.Contains(p, "/") {
return "", fmt.Errorf("Windows path contains forward slash: %s", p)
}
return strings.Replace(p, string(os.PathSeparator), "/", -1), nil
}
// chmodTarEntry is used to adjust the file permissions used in tar header based
// on the platform the archival is done.
func chmodTarEntry(perm os.FileMode) os.FileMode {
perm &= 0755
// Add the x bit: make everything +x from windows
perm |= 0111
return perm
}
func setHeaderForSpecialDevice(hdr *tar.Header, ta *tarAppender, name string, stat interface{}) (inode uint64, err error) {
// do nothing. no notion of Rdev, Inode, Nlink in stat on Windows
return
}
// handleTarTypeBlockCharFifo is an OS-specific helper function used by
// createTarFile to handle the following types of header: Block; Char; Fifo
func handleTarTypeBlockCharFifo(hdr *tar.Header, path string) error {
return nil
}
func handleLChmod(hdr *tar.Header, path string, hdrInfo os.FileInfo) error {
return nil
}
func getFileUIDGID(stat interface{}) (int, int, error) {
// no notion of file ownership mapping yet on Windows
return 0, 0, nil
}

View File

@@ -0,0 +1,87 @@
// +build windows
package archive
import (
"io/ioutil"
"os"
"path/filepath"
"testing"
)
func TestCopyFileWithInvalidDest(t *testing.T) {
folder, err := ioutil.TempDir("", "docker-archive-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(folder)
dest := "c:dest"
srcFolder := filepath.Join(folder, "src")
src := filepath.Join(folder, "src", "src")
err = os.MkdirAll(srcFolder, 0740)
if err != nil {
t.Fatal(err)
}
ioutil.WriteFile(src, []byte("content"), 0777)
err = CopyWithTar(src, dest)
if err == nil {
t.Fatalf("archiver.CopyWithTar should throw an error on invalid dest.")
}
}
func TestCanonicalTarNameForPath(t *testing.T) {
cases := []struct {
in, expected string
shouldFail bool
}{
{"foo", "foo", false},
{"foo/bar", "___", true}, // unix-styled windows path must fail
{`foo\bar`, "foo/bar", false},
}
for _, v := range cases {
if out, err := CanonicalTarNameForPath(v.in); err != nil && !v.shouldFail {
t.Fatalf("cannot get canonical name for path: %s: %v", v.in, err)
} else if v.shouldFail && err == nil {
t.Fatalf("canonical path call should have failed with error. in=%s out=%s", v.in, out)
} else if !v.shouldFail && out != v.expected {
t.Fatalf("wrong canonical tar name. expected:%s got:%s", v.expected, out)
}
}
}
func TestCanonicalTarName(t *testing.T) {
cases := []struct {
in string
isDir bool
expected string
}{
{"foo", false, "foo"},
{"foo", true, "foo/"},
{`foo\bar`, false, "foo/bar"},
{`foo\bar`, true, "foo/bar/"},
}
for _, v := range cases {
if out, err := canonicalTarName(v.in, v.isDir); err != nil {
t.Fatalf("cannot get canonical name for path: %s: %v", v.in, err)
} else if out != v.expected {
t.Fatalf("wrong canonical tar name. expected:%s got:%s", v.expected, out)
}
}
}
func TestChmodTarEntry(t *testing.T) {
cases := []struct {
in, expected os.FileMode
}{
{0000, 0111},
{0777, 0755},
{0644, 0755},
{0755, 0755},
{0444, 0555},
}
for _, v := range cases {
if out := chmodTarEntry(v.in); out != v.expected {
t.Fatalf("wrong chmod. expected:%v got:%v", v.expected, out)
}
}
}

View File

@@ -0,0 +1,416 @@
package archive
import (
"archive/tar"
"bytes"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"sort"
"strings"
"syscall"
"time"
"github.com/Sirupsen/logrus"
"github.com/hyperhq/hypercli/pkg/idtools"
"github.com/hyperhq/hypercli/pkg/pools"
"github.com/hyperhq/hypercli/pkg/system"
)
// ChangeType represents the change type.
type ChangeType int
const (
// ChangeModify represents the modify operation.
ChangeModify = iota
// ChangeAdd represents the add operation.
ChangeAdd
// ChangeDelete represents the delete operation.
ChangeDelete
)
func (c ChangeType) String() string {
switch c {
case ChangeModify:
return "C"
case ChangeAdd:
return "A"
case ChangeDelete:
return "D"
}
return ""
}
// Change represents a change, it wraps the change type and path.
// It describes changes of the files in the path respect to the
// parent layers. The change could be modify, add, delete.
// This is used for layer diff.
type Change struct {
Path string
Kind ChangeType
}
func (change *Change) String() string {
return fmt.Sprintf("%s %s", change.Kind, change.Path)
}
// for sort.Sort
type changesByPath []Change
func (c changesByPath) Less(i, j int) bool { return c[i].Path < c[j].Path }
func (c changesByPath) Len() int { return len(c) }
func (c changesByPath) Swap(i, j int) { c[j], c[i] = c[i], c[j] }
// Gnu tar and the go tar writer don't have sub-second mtime
// precision, which is problematic when we apply changes via tar
// files, we handle this by comparing for exact times, *or* same
// second count and either a or b having exactly 0 nanoseconds
func sameFsTime(a, b time.Time) bool {
return a == b ||
(a.Unix() == b.Unix() &&
(a.Nanosecond() == 0 || b.Nanosecond() == 0))
}
func sameFsTimeSpec(a, b syscall.Timespec) bool {
return a.Sec == b.Sec &&
(a.Nsec == b.Nsec || a.Nsec == 0 || b.Nsec == 0)
}
// Changes walks the path rw and determines changes for the files in the path,
// with respect to the parent layers
func Changes(layers []string, rw string) ([]Change, error) {
var (
changes []Change
changedDirs = make(map[string]struct{})
)
err := filepath.Walk(rw, func(path string, f os.FileInfo, err error) error {
if err != nil {
return err
}
// Rebase path
path, err = filepath.Rel(rw, path)
if err != nil {
return err
}
// As this runs on the daemon side, file paths are OS specific.
path = filepath.Join(string(os.PathSeparator), path)
// Skip root
if path == string(os.PathSeparator) {
return nil
}
// Skip AUFS metadata
if matched, err := filepath.Match(string(os.PathSeparator)+WhiteoutMetaPrefix+"*", path); err != nil || matched {
return err
}
change := Change{
Path: path,
}
// Find out what kind of modification happened
file := filepath.Base(path)
// If there is a whiteout, then the file was removed
if strings.HasPrefix(file, WhiteoutPrefix) {
originalFile := file[len(WhiteoutPrefix):]
change.Path = filepath.Join(filepath.Dir(path), originalFile)
change.Kind = ChangeDelete
} else {
// Otherwise, the file was added
change.Kind = ChangeAdd
// ...Unless it already existed in a top layer, in which case, it's a modification
for _, layer := range layers {
stat, err := os.Stat(filepath.Join(layer, path))
if err != nil && !os.IsNotExist(err) {
return err
}
if err == nil {
// The file existed in the top layer, so that's a modification
// However, if it's a directory, maybe it wasn't actually modified.
// If you modify /foo/bar/baz, then /foo will be part of the changed files only because it's the parent of bar
if stat.IsDir() && f.IsDir() {
if f.Size() == stat.Size() && f.Mode() == stat.Mode() && sameFsTime(f.ModTime(), stat.ModTime()) {
// Both directories are the same, don't record the change
return nil
}
}
change.Kind = ChangeModify
break
}
}
}
// If /foo/bar/file.txt is modified, then /foo/bar must be part of the changed files.
// This block is here to ensure the change is recorded even if the
// modify time, mode and size of the parent directory in the rw and ro layers are all equal.
// Check https://github.com/hyperhq/hypercli/pull/13590 for details.
if f.IsDir() {
changedDirs[path] = struct{}{}
}
if change.Kind == ChangeAdd || change.Kind == ChangeDelete {
parent := filepath.Dir(path)
if _, ok := changedDirs[parent]; !ok && parent != "/" {
changes = append(changes, Change{Path: parent, Kind: ChangeModify})
changedDirs[parent] = struct{}{}
}
}
// Record change
changes = append(changes, change)
return nil
})
if err != nil && !os.IsNotExist(err) {
return nil, err
}
return changes, nil
}
// FileInfo describes the information of a file.
type FileInfo struct {
parent *FileInfo
name string
stat *system.StatT
children map[string]*FileInfo
capability []byte
added bool
}
// LookUp looks up the file information of a file.
func (info *FileInfo) LookUp(path string) *FileInfo {
// As this runs on the daemon side, file paths are OS specific.
parent := info
if path == string(os.PathSeparator) {
return info
}
pathElements := strings.Split(path, string(os.PathSeparator))
for _, elem := range pathElements {
if elem != "" {
child := parent.children[elem]
if child == nil {
return nil
}
parent = child
}
}
return parent
}
func (info *FileInfo) path() string {
if info.parent == nil {
// As this runs on the daemon side, file paths are OS specific.
return string(os.PathSeparator)
}
return filepath.Join(info.parent.path(), info.name)
}
func (info *FileInfo) addChanges(oldInfo *FileInfo, changes *[]Change) {
sizeAtEntry := len(*changes)
if oldInfo == nil {
// add
change := Change{
Path: info.path(),
Kind: ChangeAdd,
}
*changes = append(*changes, change)
info.added = true
}
// We make a copy so we can modify it to detect additions
// also, we only recurse on the old dir if the new info is a directory
// otherwise any previous delete/change is considered recursive
oldChildren := make(map[string]*FileInfo)
if oldInfo != nil && info.isDir() {
for k, v := range oldInfo.children {
oldChildren[k] = v
}
}
for name, newChild := range info.children {
oldChild, _ := oldChildren[name]
if oldChild != nil {
// change?
oldStat := oldChild.stat
newStat := newChild.stat
// Note: We can't compare inode or ctime or blocksize here, because these change
// when copying a file into a container. However, that is not generally a problem
// because any content change will change mtime, and any status change should
// be visible when actually comparing the stat fields. The only time this
// breaks down is if some code intentionally hides a change by setting
// back mtime
if statDifferent(oldStat, newStat) ||
bytes.Compare(oldChild.capability, newChild.capability) != 0 {
change := Change{
Path: newChild.path(),
Kind: ChangeModify,
}
*changes = append(*changes, change)
newChild.added = true
}
// Remove from copy so we can detect deletions
delete(oldChildren, name)
}
newChild.addChanges(oldChild, changes)
}
for _, oldChild := range oldChildren {
// delete
change := Change{
Path: oldChild.path(),
Kind: ChangeDelete,
}
*changes = append(*changes, change)
}
// If there were changes inside this directory, we need to add it, even if the directory
// itself wasn't changed. This is needed to properly save and restore filesystem permissions.
// As this runs on the daemon side, file paths are OS specific.
if len(*changes) > sizeAtEntry && info.isDir() && !info.added && info.path() != string(os.PathSeparator) {
change := Change{
Path: info.path(),
Kind: ChangeModify,
}
// Let's insert the directory entry before the recently added entries located inside this dir
*changes = append(*changes, change) // just to resize the slice, will be overwritten
copy((*changes)[sizeAtEntry+1:], (*changes)[sizeAtEntry:])
(*changes)[sizeAtEntry] = change
}
}
// Changes add changes to file information.
func (info *FileInfo) Changes(oldInfo *FileInfo) []Change {
var changes []Change
info.addChanges(oldInfo, &changes)
return changes
}
func newRootFileInfo() *FileInfo {
// As this runs on the daemon side, file paths are OS specific.
root := &FileInfo{
name: string(os.PathSeparator),
children: make(map[string]*FileInfo),
}
return root
}
// ChangesDirs compares two directories and generates an array of Change objects describing the changes.
// If oldDir is "", then all files in newDir will be Add-Changes.
func ChangesDirs(newDir, oldDir string) ([]Change, error) {
var (
oldRoot, newRoot *FileInfo
)
if oldDir == "" {
emptyDir, err := ioutil.TempDir("", "empty")
if err != nil {
return nil, err
}
defer os.Remove(emptyDir)
oldDir = emptyDir
}
oldRoot, newRoot, err := collectFileInfoForChanges(oldDir, newDir)
if err != nil {
return nil, err
}
return newRoot.Changes(oldRoot), nil
}
// ChangesSize calculates the size in bytes of the provided changes, based on newDir.
func ChangesSize(newDir string, changes []Change) int64 {
var (
size int64
sf = make(map[uint64]struct{})
)
for _, change := range changes {
if change.Kind == ChangeModify || change.Kind == ChangeAdd {
file := filepath.Join(newDir, change.Path)
fileInfo, err := os.Lstat(file)
if err != nil {
logrus.Errorf("Can not stat %q: %s", file, err)
continue
}
if fileInfo != nil && !fileInfo.IsDir() {
if hasHardlinks(fileInfo) {
inode := getIno(fileInfo)
if _, ok := sf[inode]; !ok {
size += fileInfo.Size()
sf[inode] = struct{}{}
}
} else {
size += fileInfo.Size()
}
}
}
}
return size
}
// ExportChanges produces an Archive from the provided changes, relative to dir.
func ExportChanges(dir string, changes []Change, uidMaps, gidMaps []idtools.IDMap) (Archive, error) {
reader, writer := io.Pipe()
go func() {
ta := &tarAppender{
TarWriter: tar.NewWriter(writer),
Buffer: pools.BufioWriter32KPool.Get(nil),
SeenFiles: make(map[uint64]string),
UIDMaps: uidMaps,
GIDMaps: gidMaps,
}
// this buffer is needed for the duration of this piped stream
defer pools.BufioWriter32KPool.Put(ta.Buffer)
sort.Sort(changesByPath(changes))
// In general we log errors here but ignore them because
// during e.g. a diff operation the container can continue
// mutating the filesystem and we can see transient errors
// from this
for _, change := range changes {
if change.Kind == ChangeDelete {
whiteOutDir := filepath.Dir(change.Path)
whiteOutBase := filepath.Base(change.Path)
whiteOut := filepath.Join(whiteOutDir, WhiteoutPrefix+whiteOutBase)
timestamp := time.Now()
hdr := &tar.Header{
Name: whiteOut[1:],
Size: 0,
ModTime: timestamp,
AccessTime: timestamp,
ChangeTime: timestamp,
}
if err := ta.TarWriter.WriteHeader(hdr); err != nil {
logrus.Debugf("Can't write whiteout header: %s", err)
}
} else {
path := filepath.Join(dir, change.Path)
if err := ta.addTarFile(path, change.Path[1:]); err != nil {
logrus.Debugf("Can't add file %s to tar: %s", path, err)
}
}
}
// Make sure to check the error on Close.
if err := ta.TarWriter.Close(); err != nil {
logrus.Debugf("Can't close layer: %s", err)
}
if err := writer.Close(); err != nil {
logrus.Debugf("failed close Changes writer: %s", err)
}
}()
return reader, nil
}

View File

@@ -0,0 +1,285 @@
package archive
import (
"bytes"
"fmt"
"os"
"path/filepath"
"sort"
"syscall"
"unsafe"
"github.com/hyperhq/hypercli/pkg/system"
)
// walker is used to implement collectFileInfoForChanges on linux. Where this
// method in general returns the entire contents of two directory trees, we
// optimize some FS calls out on linux. In particular, we take advantage of the
// fact that getdents(2) returns the inode of each file in the directory being
// walked, which, when walking two trees in parallel to generate a list of
// changes, can be used to prune subtrees without ever having to lstat(2) them
// directly. Eliminating stat calls in this way can save up to seconds on large
// images.
type walker struct {
dir1 string
dir2 string
root1 *FileInfo
root2 *FileInfo
}
// collectFileInfoForChanges returns a complete representation of the trees
// rooted at dir1 and dir2, with one important exception: any subtree or
// leaf where the inode and device numbers are an exact match between dir1
// and dir2 will be pruned from the results. This method is *only* to be used
// to generating a list of changes between the two directories, as it does not
// reflect the full contents.
func collectFileInfoForChanges(dir1, dir2 string) (*FileInfo, *FileInfo, error) {
w := &walker{
dir1: dir1,
dir2: dir2,
root1: newRootFileInfo(),
root2: newRootFileInfo(),
}
i1, err := os.Lstat(w.dir1)
if err != nil {
return nil, nil, err
}
i2, err := os.Lstat(w.dir2)
if err != nil {
return nil, nil, err
}
if err := w.walk("/", i1, i2); err != nil {
return nil, nil, err
}
return w.root1, w.root2, nil
}
// Given a FileInfo, its path info, and a reference to the root of the tree
// being constructed, register this file with the tree.
func walkchunk(path string, fi os.FileInfo, dir string, root *FileInfo) error {
if fi == nil {
return nil
}
parent := root.LookUp(filepath.Dir(path))
if parent == nil {
return fmt.Errorf("collectFileInfoForChanges: Unexpectedly no parent for %s", path)
}
info := &FileInfo{
name: filepath.Base(path),
children: make(map[string]*FileInfo),
parent: parent,
}
cpath := filepath.Join(dir, path)
stat, err := system.FromStatT(fi.Sys().(*syscall.Stat_t))
if err != nil {
return err
}
info.stat = stat
info.capability, _ = system.Lgetxattr(cpath, "security.capability") // lgetxattr(2): fs access
parent.children[info.name] = info
return nil
}
// Walk a subtree rooted at the same path in both trees being iterated. For
// example, /docker/overlay/1234/a/b/c/d and /docker/overlay/8888/a/b/c/d
func (w *walker) walk(path string, i1, i2 os.FileInfo) (err error) {
// Register these nodes with the return trees, unless we're still at the
// (already-created) roots:
if path != "/" {
if err := walkchunk(path, i1, w.dir1, w.root1); err != nil {
return err
}
if err := walkchunk(path, i2, w.dir2, w.root2); err != nil {
return err
}
}
is1Dir := i1 != nil && i1.IsDir()
is2Dir := i2 != nil && i2.IsDir()
sameDevice := false
if i1 != nil && i2 != nil {
si1 := i1.Sys().(*syscall.Stat_t)
si2 := i2.Sys().(*syscall.Stat_t)
if si1.Dev == si2.Dev {
sameDevice = true
}
}
// If these files are both non-existent, or leaves (non-dirs), we are done.
if !is1Dir && !is2Dir {
return nil
}
// Fetch the names of all the files contained in both directories being walked:
var names1, names2 []nameIno
if is1Dir {
names1, err = readdirnames(filepath.Join(w.dir1, path)) // getdents(2): fs access
if err != nil {
return err
}
}
if is2Dir {
names2, err = readdirnames(filepath.Join(w.dir2, path)) // getdents(2): fs access
if err != nil {
return err
}
}
// We have lists of the files contained in both parallel directories, sorted
// in the same order. Walk them in parallel, generating a unique merged list
// of all items present in either or both directories.
var names []string
ix1 := 0
ix2 := 0
for {
if ix1 >= len(names1) {
break
}
if ix2 >= len(names2) {
break
}
ni1 := names1[ix1]
ni2 := names2[ix2]
switch bytes.Compare([]byte(ni1.name), []byte(ni2.name)) {
case -1: // ni1 < ni2 -- advance ni1
// we will not encounter ni1 in names2
names = append(names, ni1.name)
ix1++
case 0: // ni1 == ni2
if ni1.ino != ni2.ino || !sameDevice {
names = append(names, ni1.name)
}
ix1++
ix2++
case 1: // ni1 > ni2 -- advance ni2
// we will not encounter ni2 in names1
names = append(names, ni2.name)
ix2++
}
}
for ix1 < len(names1) {
names = append(names, names1[ix1].name)
ix1++
}
for ix2 < len(names2) {
names = append(names, names2[ix2].name)
ix2++
}
// For each of the names present in either or both of the directories being
// iterated, stat the name under each root, and recurse the pair of them:
for _, name := range names {
fname := filepath.Join(path, name)
var cInfo1, cInfo2 os.FileInfo
if is1Dir {
cInfo1, err = os.Lstat(filepath.Join(w.dir1, fname)) // lstat(2): fs access
if err != nil && !os.IsNotExist(err) {
return err
}
}
if is2Dir {
cInfo2, err = os.Lstat(filepath.Join(w.dir2, fname)) // lstat(2): fs access
if err != nil && !os.IsNotExist(err) {
return err
}
}
if err = w.walk(fname, cInfo1, cInfo2); err != nil {
return err
}
}
return nil
}
// {name,inode} pairs used to support the early-pruning logic of the walker type
type nameIno struct {
name string
ino uint64
}
type nameInoSlice []nameIno
func (s nameInoSlice) Len() int { return len(s) }
func (s nameInoSlice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (s nameInoSlice) Less(i, j int) bool { return s[i].name < s[j].name }
// readdirnames is a hacked-apart version of the Go stdlib code, exposing inode
// numbers further up the stack when reading directory contents. Unlike
// os.Readdirnames, which returns a list of filenames, this function returns a
// list of {filename,inode} pairs.
func readdirnames(dirname string) (names []nameIno, err error) {
var (
size = 100
buf = make([]byte, 4096)
nbuf int
bufp int
nb int
)
f, err := os.Open(dirname)
if err != nil {
return nil, err
}
defer f.Close()
names = make([]nameIno, 0, size) // Empty with room to grow.
for {
// Refill the buffer if necessary
if bufp >= nbuf {
bufp = 0
nbuf, err = syscall.ReadDirent(int(f.Fd()), buf) // getdents on linux
if nbuf < 0 {
nbuf = 0
}
if err != nil {
return nil, os.NewSyscallError("readdirent", err)
}
if nbuf <= 0 {
break // EOF
}
}
// Drain the buffer
nb, names = parseDirent(buf[bufp:nbuf], names)
bufp += nb
}
sl := nameInoSlice(names)
sort.Sort(sl)
return sl, nil
}
// parseDirent is a minor modification of syscall.ParseDirent (linux version)
// which returns {name,inode} pairs instead of just names.
func parseDirent(buf []byte, names []nameIno) (consumed int, newnames []nameIno) {
origlen := len(buf)
for len(buf) > 0 {
dirent := (*syscall.Dirent)(unsafe.Pointer(&buf[0]))
buf = buf[dirent.Reclen:]
if dirent.Ino == 0 { // File absent in directory.
continue
}
bytes := (*[10000]byte)(unsafe.Pointer(&dirent.Name[0]))
var name = string(bytes[0:clen(bytes[:])])
if name == "." || name == ".." { // Useless names
continue
}
names = append(names, nameIno{name, dirent.Ino})
}
return origlen - len(buf), names
}
func clen(n []byte) int {
for i := 0; i < len(n); i++ {
if n[i] == 0 {
return i
}
}
return len(n)
}

View File

@@ -0,0 +1,97 @@
// +build !linux
package archive
import (
"fmt"
"os"
"path/filepath"
"runtime"
"strings"
"github.com/hyperhq/hypercli/pkg/system"
)
func collectFileInfoForChanges(oldDir, newDir string) (*FileInfo, *FileInfo, error) {
var (
oldRoot, newRoot *FileInfo
err1, err2 error
errs = make(chan error, 2)
)
go func() {
oldRoot, err1 = collectFileInfo(oldDir)
errs <- err1
}()
go func() {
newRoot, err2 = collectFileInfo(newDir)
errs <- err2
}()
// block until both routines have returned
for i := 0; i < 2; i++ {
if err := <-errs; err != nil {
return nil, nil, err
}
}
return oldRoot, newRoot, nil
}
func collectFileInfo(sourceDir string) (*FileInfo, error) {
root := newRootFileInfo()
err := filepath.Walk(sourceDir, func(path string, f os.FileInfo, err error) error {
if err != nil {
return err
}
// Rebase path
relPath, err := filepath.Rel(sourceDir, path)
if err != nil {
return err
}
// As this runs on the daemon side, file paths are OS specific.
relPath = filepath.Join(string(os.PathSeparator), relPath)
// See https://github.com/golang/go/issues/9168 - bug in filepath.Join.
// Temporary workaround. If the returned path starts with two backslashes,
// trim it down to a single backslash. Only relevant on Windows.
if runtime.GOOS == "windows" {
if strings.HasPrefix(relPath, `\\`) {
relPath = relPath[1:]
}
}
if relPath == string(os.PathSeparator) {
return nil
}
parent := root.LookUp(filepath.Dir(relPath))
if parent == nil {
return fmt.Errorf("collectFileInfo: Unexpectedly no parent for %s", relPath)
}
info := &FileInfo{
name: filepath.Base(relPath),
children: make(map[string]*FileInfo),
parent: parent,
}
s, err := system.Lstat(path)
if err != nil {
return err
}
info.stat = s
info.capability, _ = system.Lgetxattr(path, "security.capability")
parent.children[info.name] = info
return nil
})
if err != nil {
return nil, err
}
return root, nil
}

View File

@@ -0,0 +1,127 @@
package archive
import (
"archive/tar"
"fmt"
"io"
"io/ioutil"
"os"
"path"
"sort"
"testing"
)
func TestHardLinkOrder(t *testing.T) {
names := []string{"file1.txt", "file2.txt", "file3.txt"}
msg := []byte("Hey y'all")
// Create dir
src, err := ioutil.TempDir("", "docker-hardlink-test-src-")
if err != nil {
t.Fatal(err)
}
//defer os.RemoveAll(src)
for _, name := range names {
func() {
fh, err := os.Create(path.Join(src, name))
if err != nil {
t.Fatal(err)
}
defer fh.Close()
if _, err = fh.Write(msg); err != nil {
t.Fatal(err)
}
}()
}
// Create dest, with changes that includes hardlinks
dest, err := ioutil.TempDir("", "docker-hardlink-test-dest-")
if err != nil {
t.Fatal(err)
}
os.RemoveAll(dest) // we just want the name, at first
if err := copyDir(src, dest); err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dest)
for _, name := range names {
for i := 0; i < 5; i++ {
if err := os.Link(path.Join(dest, name), path.Join(dest, fmt.Sprintf("%s.link%d", name, i))); err != nil {
t.Fatal(err)
}
}
}
// get changes
changes, err := ChangesDirs(dest, src)
if err != nil {
t.Fatal(err)
}
// sort
sort.Sort(changesByPath(changes))
// ExportChanges
ar, err := ExportChanges(dest, changes, nil, nil)
if err != nil {
t.Fatal(err)
}
hdrs, err := walkHeaders(ar)
if err != nil {
t.Fatal(err)
}
// reverse sort
sort.Sort(sort.Reverse(changesByPath(changes)))
// ExportChanges
arRev, err := ExportChanges(dest, changes, nil, nil)
if err != nil {
t.Fatal(err)
}
hdrsRev, err := walkHeaders(arRev)
if err != nil {
t.Fatal(err)
}
// line up the two sets
sort.Sort(tarHeaders(hdrs))
sort.Sort(tarHeaders(hdrsRev))
// compare Size and LinkName
for i := range hdrs {
if hdrs[i].Name != hdrsRev[i].Name {
t.Errorf("headers - expected name %q; but got %q", hdrs[i].Name, hdrsRev[i].Name)
}
if hdrs[i].Size != hdrsRev[i].Size {
t.Errorf("headers - %q expected size %d; but got %d", hdrs[i].Name, hdrs[i].Size, hdrsRev[i].Size)
}
if hdrs[i].Typeflag != hdrsRev[i].Typeflag {
t.Errorf("headers - %q expected type %d; but got %d", hdrs[i].Name, hdrs[i].Typeflag, hdrsRev[i].Typeflag)
}
if hdrs[i].Linkname != hdrsRev[i].Linkname {
t.Errorf("headers - %q expected linkname %q; but got %q", hdrs[i].Name, hdrs[i].Linkname, hdrsRev[i].Linkname)
}
}
}
type tarHeaders []tar.Header
func (th tarHeaders) Len() int { return len(th) }
func (th tarHeaders) Swap(i, j int) { th[j], th[i] = th[i], th[j] }
func (th tarHeaders) Less(i, j int) bool { return th[i].Name < th[j].Name }
func walkHeaders(r io.Reader) ([]tar.Header, error) {
t := tar.NewReader(r)
headers := []tar.Header{}
for {
hdr, err := t.Next()
if err != nil {
if err == io.EOF {
break
}
return headers, err
}
headers = append(headers, *hdr)
}
return headers, nil
}

View File

@@ -0,0 +1,527 @@
package archive
import (
"io/ioutil"
"os"
"os/exec"
"path"
"sort"
"testing"
"time"
)
func max(x, y int) int {
if x >= y {
return x
}
return y
}
func copyDir(src, dst string) error {
cmd := exec.Command("cp", "-a", src, dst)
if err := cmd.Run(); err != nil {
return err
}
return nil
}
type FileType uint32
const (
Regular FileType = iota
Dir
Symlink
)
type FileData struct {
filetype FileType
path string
contents string
permissions os.FileMode
}
func createSampleDir(t *testing.T, root string) {
files := []FileData{
{Regular, "file1", "file1\n", 0600},
{Regular, "file2", "file2\n", 0666},
{Regular, "file3", "file3\n", 0404},
{Regular, "file4", "file4\n", 0600},
{Regular, "file5", "file5\n", 0600},
{Regular, "file6", "file6\n", 0600},
{Regular, "file7", "file7\n", 0600},
{Dir, "dir1", "", 0740},
{Regular, "dir1/file1-1", "file1-1\n", 01444},
{Regular, "dir1/file1-2", "file1-2\n", 0666},
{Dir, "dir2", "", 0700},
{Regular, "dir2/file2-1", "file2-1\n", 0666},
{Regular, "dir2/file2-2", "file2-2\n", 0666},
{Dir, "dir3", "", 0700},
{Regular, "dir3/file3-1", "file3-1\n", 0666},
{Regular, "dir3/file3-2", "file3-2\n", 0666},
{Dir, "dir4", "", 0700},
{Regular, "dir4/file3-1", "file4-1\n", 0666},
{Regular, "dir4/file3-2", "file4-2\n", 0666},
{Symlink, "symlink1", "target1", 0666},
{Symlink, "symlink2", "target2", 0666},
{Symlink, "symlink3", root + "/file1", 0666},
{Symlink, "symlink4", root + "/symlink3", 0666},
{Symlink, "dirSymlink", root + "/dir1", 0740},
}
now := time.Now()
for _, info := range files {
p := path.Join(root, info.path)
if info.filetype == Dir {
if err := os.MkdirAll(p, info.permissions); err != nil {
t.Fatal(err)
}
} else if info.filetype == Regular {
if err := ioutil.WriteFile(p, []byte(info.contents), info.permissions); err != nil {
t.Fatal(err)
}
} else if info.filetype == Symlink {
if err := os.Symlink(info.contents, p); err != nil {
t.Fatal(err)
}
}
if info.filetype != Symlink {
// Set a consistent ctime, atime for all files and dirs
if err := os.Chtimes(p, now, now); err != nil {
t.Fatal(err)
}
}
}
}
func TestChangeString(t *testing.T) {
modifiyChange := Change{"change", ChangeModify}
toString := modifiyChange.String()
if toString != "C change" {
t.Fatalf("String() of a change with ChangeModifiy Kind should have been %s but was %s", "C change", toString)
}
addChange := Change{"change", ChangeAdd}
toString = addChange.String()
if toString != "A change" {
t.Fatalf("String() of a change with ChangeAdd Kind should have been %s but was %s", "A change", toString)
}
deleteChange := Change{"change", ChangeDelete}
toString = deleteChange.String()
if toString != "D change" {
t.Fatalf("String() of a change with ChangeDelete Kind should have been %s but was %s", "D change", toString)
}
}
func TestChangesWithNoChanges(t *testing.T) {
rwLayer, err := ioutil.TempDir("", "docker-changes-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(rwLayer)
layer, err := ioutil.TempDir("", "docker-changes-test-layer")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(layer)
createSampleDir(t, layer)
changes, err := Changes([]string{layer}, rwLayer)
if err != nil {
t.Fatal(err)
}
if len(changes) != 0 {
t.Fatalf("Changes with no difference should have detect no changes, but detected %d", len(changes))
}
}
func TestChangesWithChanges(t *testing.T) {
// Mock the readonly layer
layer, err := ioutil.TempDir("", "docker-changes-test-layer")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(layer)
createSampleDir(t, layer)
os.MkdirAll(path.Join(layer, "dir1/subfolder"), 0740)
// Mock the RW layer
rwLayer, err := ioutil.TempDir("", "docker-changes-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(rwLayer)
// Create a folder in RW layer
dir1 := path.Join(rwLayer, "dir1")
os.MkdirAll(dir1, 0740)
deletedFile := path.Join(dir1, ".wh.file1-2")
ioutil.WriteFile(deletedFile, []byte{}, 0600)
modifiedFile := path.Join(dir1, "file1-1")
ioutil.WriteFile(modifiedFile, []byte{0x00}, 01444)
// Let's add a subfolder for a newFile
subfolder := path.Join(dir1, "subfolder")
os.MkdirAll(subfolder, 0740)
newFile := path.Join(subfolder, "newFile")
ioutil.WriteFile(newFile, []byte{}, 0740)
changes, err := Changes([]string{layer}, rwLayer)
if err != nil {
t.Fatal(err)
}
expectedChanges := []Change{
{"/dir1", ChangeModify},
{"/dir1/file1-1", ChangeModify},
{"/dir1/file1-2", ChangeDelete},
{"/dir1/subfolder", ChangeModify},
{"/dir1/subfolder/newFile", ChangeAdd},
}
checkChanges(expectedChanges, changes, t)
}
// See https://github.com/hyperhq/hypercli/pull/13590
func TestChangesWithChangesGH13590(t *testing.T) {
baseLayer, err := ioutil.TempDir("", "docker-changes-test.")
defer os.RemoveAll(baseLayer)
dir3 := path.Join(baseLayer, "dir1/dir2/dir3")
os.MkdirAll(dir3, 07400)
file := path.Join(dir3, "file.txt")
ioutil.WriteFile(file, []byte("hello"), 0666)
layer, err := ioutil.TempDir("", "docker-changes-test2.")
defer os.RemoveAll(layer)
// Test creating a new file
if err := copyDir(baseLayer+"/dir1", layer+"/"); err != nil {
t.Fatalf("Cmd failed: %q", err)
}
os.Remove(path.Join(layer, "dir1/dir2/dir3/file.txt"))
file = path.Join(layer, "dir1/dir2/dir3/file1.txt")
ioutil.WriteFile(file, []byte("bye"), 0666)
changes, err := Changes([]string{baseLayer}, layer)
if err != nil {
t.Fatal(err)
}
expectedChanges := []Change{
{"/dir1/dir2/dir3", ChangeModify},
{"/dir1/dir2/dir3/file1.txt", ChangeAdd},
}
checkChanges(expectedChanges, changes, t)
// Now test changing a file
layer, err = ioutil.TempDir("", "docker-changes-test3.")
defer os.RemoveAll(layer)
if err := copyDir(baseLayer+"/dir1", layer+"/"); err != nil {
t.Fatalf("Cmd failed: %q", err)
}
file = path.Join(layer, "dir1/dir2/dir3/file.txt")
ioutil.WriteFile(file, []byte("bye"), 0666)
changes, err = Changes([]string{baseLayer}, layer)
if err != nil {
t.Fatal(err)
}
expectedChanges = []Change{
{"/dir1/dir2/dir3/file.txt", ChangeModify},
}
checkChanges(expectedChanges, changes, t)
}
// Create an directory, copy it, make sure we report no changes between the two
func TestChangesDirsEmpty(t *testing.T) {
src, err := ioutil.TempDir("", "docker-changes-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(src)
createSampleDir(t, src)
dst := src + "-copy"
if err := copyDir(src, dst); err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dst)
changes, err := ChangesDirs(dst, src)
if err != nil {
t.Fatal(err)
}
if len(changes) != 0 {
t.Fatalf("Reported changes for identical dirs: %v", changes)
}
os.RemoveAll(src)
os.RemoveAll(dst)
}
func mutateSampleDir(t *testing.T, root string) {
// Remove a regular file
if err := os.RemoveAll(path.Join(root, "file1")); err != nil {
t.Fatal(err)
}
// Remove a directory
if err := os.RemoveAll(path.Join(root, "dir1")); err != nil {
t.Fatal(err)
}
// Remove a symlink
if err := os.RemoveAll(path.Join(root, "symlink1")); err != nil {
t.Fatal(err)
}
// Rewrite a file
if err := ioutil.WriteFile(path.Join(root, "file2"), []byte("fileNN\n"), 0777); err != nil {
t.Fatal(err)
}
// Replace a file
if err := os.RemoveAll(path.Join(root, "file3")); err != nil {
t.Fatal(err)
}
if err := ioutil.WriteFile(path.Join(root, "file3"), []byte("fileMM\n"), 0404); err != nil {
t.Fatal(err)
}
// Touch file
if err := os.Chtimes(path.Join(root, "file4"), time.Now().Add(time.Second), time.Now().Add(time.Second)); err != nil {
t.Fatal(err)
}
// Replace file with dir
if err := os.RemoveAll(path.Join(root, "file5")); err != nil {
t.Fatal(err)
}
if err := os.MkdirAll(path.Join(root, "file5"), 0666); err != nil {
t.Fatal(err)
}
// Create new file
if err := ioutil.WriteFile(path.Join(root, "filenew"), []byte("filenew\n"), 0777); err != nil {
t.Fatal(err)
}
// Create new dir
if err := os.MkdirAll(path.Join(root, "dirnew"), 0766); err != nil {
t.Fatal(err)
}
// Create a new symlink
if err := os.Symlink("targetnew", path.Join(root, "symlinknew")); err != nil {
t.Fatal(err)
}
// Change a symlink
if err := os.RemoveAll(path.Join(root, "symlink2")); err != nil {
t.Fatal(err)
}
if err := os.Symlink("target2change", path.Join(root, "symlink2")); err != nil {
t.Fatal(err)
}
// Replace dir with file
if err := os.RemoveAll(path.Join(root, "dir2")); err != nil {
t.Fatal(err)
}
if err := ioutil.WriteFile(path.Join(root, "dir2"), []byte("dir2\n"), 0777); err != nil {
t.Fatal(err)
}
// Touch dir
if err := os.Chtimes(path.Join(root, "dir3"), time.Now().Add(time.Second), time.Now().Add(time.Second)); err != nil {
t.Fatal(err)
}
}
func TestChangesDirsMutated(t *testing.T) {
src, err := ioutil.TempDir("", "docker-changes-test")
if err != nil {
t.Fatal(err)
}
createSampleDir(t, src)
dst := src + "-copy"
if err := copyDir(src, dst); err != nil {
t.Fatal(err)
}
defer os.RemoveAll(src)
defer os.RemoveAll(dst)
mutateSampleDir(t, dst)
changes, err := ChangesDirs(dst, src)
if err != nil {
t.Fatal(err)
}
sort.Sort(changesByPath(changes))
expectedChanges := []Change{
{"/dir1", ChangeDelete},
{"/dir2", ChangeModify},
{"/dirnew", ChangeAdd},
{"/file1", ChangeDelete},
{"/file2", ChangeModify},
{"/file3", ChangeModify},
{"/file4", ChangeModify},
{"/file5", ChangeModify},
{"/filenew", ChangeAdd},
{"/symlink1", ChangeDelete},
{"/symlink2", ChangeModify},
{"/symlinknew", ChangeAdd},
}
for i := 0; i < max(len(changes), len(expectedChanges)); i++ {
if i >= len(expectedChanges) {
t.Fatalf("unexpected change %s\n", changes[i].String())
}
if i >= len(changes) {
t.Fatalf("no change for expected change %s\n", expectedChanges[i].String())
}
if changes[i].Path == expectedChanges[i].Path {
if changes[i] != expectedChanges[i] {
t.Fatalf("Wrong change for %s, expected %s, got %s\n", changes[i].Path, changes[i].String(), expectedChanges[i].String())
}
} else if changes[i].Path < expectedChanges[i].Path {
t.Fatalf("unexpected change %s\n", changes[i].String())
} else {
t.Fatalf("no change for expected change %s != %s\n", expectedChanges[i].String(), changes[i].String())
}
}
}
func TestApplyLayer(t *testing.T) {
src, err := ioutil.TempDir("", "docker-changes-test")
if err != nil {
t.Fatal(err)
}
createSampleDir(t, src)
defer os.RemoveAll(src)
dst := src + "-copy"
if err := copyDir(src, dst); err != nil {
t.Fatal(err)
}
mutateSampleDir(t, dst)
defer os.RemoveAll(dst)
changes, err := ChangesDirs(dst, src)
if err != nil {
t.Fatal(err)
}
layer, err := ExportChanges(dst, changes, nil, nil)
if err != nil {
t.Fatal(err)
}
layerCopy, err := NewTempArchive(layer, "")
if err != nil {
t.Fatal(err)
}
if _, err := ApplyLayer(src, layerCopy); err != nil {
t.Fatal(err)
}
changes2, err := ChangesDirs(src, dst)
if err != nil {
t.Fatal(err)
}
if len(changes2) != 0 {
t.Fatalf("Unexpected differences after reapplying mutation: %v", changes2)
}
}
func TestChangesSizeWithHardlinks(t *testing.T) {
srcDir, err := ioutil.TempDir("", "docker-test-srcDir")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(srcDir)
destDir, err := ioutil.TempDir("", "docker-test-destDir")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(destDir)
creationSize, err := prepareUntarSourceDirectory(100, destDir, true)
if err != nil {
t.Fatal(err)
}
changes, err := ChangesDirs(destDir, srcDir)
if err != nil {
t.Fatal(err)
}
got := ChangesSize(destDir, changes)
if got != int64(creationSize) {
t.Errorf("Expected %d bytes of changes, got %d", creationSize, got)
}
}
func TestChangesSizeWithNoChanges(t *testing.T) {
size := ChangesSize("/tmp", nil)
if size != 0 {
t.Fatalf("ChangesSizes with no changes should be 0, was %d", size)
}
}
func TestChangesSizeWithOnlyDeleteChanges(t *testing.T) {
changes := []Change{
{Path: "deletedPath", Kind: ChangeDelete},
}
size := ChangesSize("/tmp", changes)
if size != 0 {
t.Fatalf("ChangesSizes with only delete changes should be 0, was %d", size)
}
}
func TestChangesSize(t *testing.T) {
parentPath, err := ioutil.TempDir("", "docker-changes-test")
defer os.RemoveAll(parentPath)
addition := path.Join(parentPath, "addition")
if err := ioutil.WriteFile(addition, []byte{0x01, 0x01, 0x01}, 0744); err != nil {
t.Fatal(err)
}
modification := path.Join(parentPath, "modification")
if err = ioutil.WriteFile(modification, []byte{0x01, 0x01, 0x01}, 0744); err != nil {
t.Fatal(err)
}
changes := []Change{
{Path: "addition", Kind: ChangeAdd},
{Path: "modification", Kind: ChangeModify},
}
size := ChangesSize(parentPath, changes)
if size != 6 {
t.Fatalf("Expected 6 bytes of changes, got %d", size)
}
}
func checkChanges(expectedChanges, changes []Change, t *testing.T) {
sort.Sort(changesByPath(expectedChanges))
sort.Sort(changesByPath(changes))
for i := 0; i < max(len(changes), len(expectedChanges)); i++ {
if i >= len(expectedChanges) {
t.Fatalf("unexpected change %s\n", changes[i].String())
}
if i >= len(changes) {
t.Fatalf("no change for expected change %s\n", expectedChanges[i].String())
}
if changes[i].Path == expectedChanges[i].Path {
if changes[i] != expectedChanges[i] {
t.Fatalf("Wrong change for %s, expected %s, got %s\n", changes[i].Path, changes[i].String(), expectedChanges[i].String())
}
} else if changes[i].Path < expectedChanges[i].Path {
t.Fatalf("unexpected change %s\n", changes[i].String())
} else {
t.Fatalf("no change for expected change %s != %s\n", expectedChanges[i].String(), changes[i].String())
}
}
}

View File

@@ -0,0 +1,36 @@
// +build !windows
package archive
import (
"os"
"syscall"
"github.com/hyperhq/hypercli/pkg/system"
)
func statDifferent(oldStat *system.StatT, newStat *system.StatT) bool {
// Don't look at size for dirs, its not a good measure of change
if oldStat.Mode() != newStat.Mode() ||
oldStat.UID() != newStat.UID() ||
oldStat.GID() != newStat.GID() ||
oldStat.Rdev() != newStat.Rdev() ||
// Don't look at size for dirs, its not a good measure of change
(oldStat.Mode()&syscall.S_IFDIR != syscall.S_IFDIR &&
(!sameFsTimeSpec(oldStat.Mtim(), newStat.Mtim()) || (oldStat.Size() != newStat.Size()))) {
return true
}
return false
}
func (info *FileInfo) isDir() bool {
return info.parent == nil || info.stat.Mode()&syscall.S_IFDIR != 0
}
func getIno(fi os.FileInfo) uint64 {
return uint64(fi.Sys().(*syscall.Stat_t).Ino)
}
func hasHardlinks(fi os.FileInfo) bool {
return fi.Sys().(*syscall.Stat_t).Nlink > 1
}

View File

@@ -0,0 +1,30 @@
package archive
import (
"os"
"github.com/hyperhq/hypercli/pkg/system"
)
func statDifferent(oldStat *system.StatT, newStat *system.StatT) bool {
// Don't look at size for dirs, its not a good measure of change
if oldStat.ModTime() != newStat.ModTime() ||
oldStat.Mode() != newStat.Mode() ||
oldStat.Size() != newStat.Size() && !oldStat.IsDir() {
return true
}
return false
}
func (info *FileInfo) isDir() bool {
return info.parent == nil || info.stat.IsDir()
}
func getIno(fi os.FileInfo) (inode uint64) {
return
}
func hasHardlinks(fi os.FileInfo) bool {
return false
}

458
vendor/github.com/hyperhq/hypercli/pkg/archive/copy.go generated vendored Normal file
View File

@@ -0,0 +1,458 @@
package archive
import (
"archive/tar"
"errors"
"io"
"io/ioutil"
"os"
"path/filepath"
"strings"
"github.com/Sirupsen/logrus"
"github.com/hyperhq/hypercli/pkg/system"
)
// Errors used or returned by this file.
var (
ErrNotDirectory = errors.New("not a directory")
ErrDirNotExists = errors.New("no such directory")
ErrCannotCopyDir = errors.New("cannot copy directory")
ErrInvalidCopySource = errors.New("invalid copy source content")
)
// PreserveTrailingDotOrSeparator returns the given cleaned path (after
// processing using any utility functions from the path or filepath stdlib
// packages) and appends a trailing `/.` or `/` if its corresponding original
// path (from before being processed by utility functions from the path or
// filepath stdlib packages) ends with a trailing `/.` or `/`. If the cleaned
// path already ends in a `.` path segment, then another is not added. If the
// clean path already ends in a path separator, then another is not added.
func PreserveTrailingDotOrSeparator(cleanedPath, originalPath string) string {
// Ensure paths are in platform semantics
cleanedPath = normalizePath(cleanedPath)
originalPath = normalizePath(originalPath)
if !specifiesCurrentDir(cleanedPath) && specifiesCurrentDir(originalPath) {
if !hasTrailingPathSeparator(cleanedPath) {
// Add a separator if it doesn't already end with one (a cleaned
// path would only end in a separator if it is the root).
cleanedPath += string(filepath.Separator)
}
cleanedPath += "."
}
if !hasTrailingPathSeparator(cleanedPath) && hasTrailingPathSeparator(originalPath) {
cleanedPath += string(filepath.Separator)
}
return cleanedPath
}
// assertsDirectory returns whether the given path is
// asserted to be a directory, i.e., the path ends with
// a trailing '/' or `/.`, assuming a path separator of `/`.
func assertsDirectory(path string) bool {
return hasTrailingPathSeparator(path) || specifiesCurrentDir(path)
}
// hasTrailingPathSeparator returns whether the given
// path ends with the system's path separator character.
func hasTrailingPathSeparator(path string) bool {
return len(path) > 0 && os.IsPathSeparator(path[len(path)-1])
}
// specifiesCurrentDir returns whether the given path specifies
// a "current directory", i.e., the last path segment is `.`.
func specifiesCurrentDir(path string) bool {
return filepath.Base(path) == "."
}
// SplitPathDirEntry splits the given path between its directory name and its
// basename by first cleaning the path but preserves a trailing "." if the
// original path specified the current directory.
func SplitPathDirEntry(path string) (dir, base string) {
cleanedPath := filepath.Clean(normalizePath(path))
if specifiesCurrentDir(path) {
cleanedPath += string(filepath.Separator) + "."
}
return filepath.Dir(cleanedPath), filepath.Base(cleanedPath)
}
// TarResource archives the resource described by the given CopyInfo to a Tar
// archive. A non-nil error is returned if sourcePath does not exist or is
// asserted to be a directory but exists as another type of file.
//
// This function acts as a convenient wrapper around TarWithOptions, which
// requires a directory as the source path. TarResource accepts either a
// directory or a file path and correctly sets the Tar options.
func TarResource(sourceInfo CopyInfo) (content Archive, err error) {
return TarResourceRebase(sourceInfo.Path, sourceInfo.RebaseName)
}
// TarResourceRebase is like TarResource but renames the first path element of
// items in the resulting tar archive to match the given rebaseName if not "".
func TarResourceRebase(sourcePath, rebaseName string) (content Archive, err error) {
sourcePath = normalizePath(sourcePath)
if _, err = os.Lstat(sourcePath); err != nil {
// Catches the case where the source does not exist or is not a
// directory if asserted to be a directory, as this also causes an
// error.
return
}
// Separate the source path between it's directory and
// the entry in that directory which we are archiving.
sourceDir, sourceBase := SplitPathDirEntry(sourcePath)
filter := []string{sourceBase}
logrus.Debugf("copying %q from %q", sourceBase, sourceDir)
return TarWithOptions(sourceDir, &TarOptions{
Compression: Uncompressed,
IncludeFiles: filter,
IncludeSourceDir: true,
RebaseNames: map[string]string{
sourceBase: rebaseName,
},
})
}
// CopyInfo holds basic info about the source
// or destination path of a copy operation.
type CopyInfo struct {
Path string
Exists bool
IsDir bool
RebaseName string
}
// CopyInfoSourcePath stats the given path to create a CopyInfo
// struct representing that resource for the source of an archive copy
// operation. The given path should be an absolute local path. A source path
// has all symlinks evaluated that appear before the last path separator ("/"
// on Unix). As it is to be a copy source, the path must exist.
func CopyInfoSourcePath(path string, followLink bool) (CopyInfo, error) {
// normalize the file path and then evaluate the symbol link
// we will use the target file instead of the symbol link if
// followLink is set
path = normalizePath(path)
resolvedPath, rebaseName, err := ResolveHostSourcePath(path, followLink)
if err != nil {
return CopyInfo{}, err
}
stat, err := os.Lstat(resolvedPath)
if err != nil {
return CopyInfo{}, err
}
return CopyInfo{
Path: resolvedPath,
Exists: true,
IsDir: stat.IsDir(),
RebaseName: rebaseName,
}, nil
}
// CopyInfoDestinationPath stats the given path to create a CopyInfo
// struct representing that resource for the destination of an archive copy
// operation. The given path should be an absolute local path.
func CopyInfoDestinationPath(path string) (info CopyInfo, err error) {
maxSymlinkIter := 10 // filepath.EvalSymlinks uses 255, but 10 already seems like a lot.
path = normalizePath(path)
originalPath := path
stat, err := os.Lstat(path)
if err == nil && stat.Mode()&os.ModeSymlink == 0 {
// The path exists and is not a symlink.
return CopyInfo{
Path: path,
Exists: true,
IsDir: stat.IsDir(),
}, nil
}
// While the path is a symlink.
for n := 0; err == nil && stat.Mode()&os.ModeSymlink != 0; n++ {
if n > maxSymlinkIter {
// Don't follow symlinks more than this arbitrary number of times.
return CopyInfo{}, errors.New("too many symlinks in " + originalPath)
}
// The path is a symbolic link. We need to evaluate it so that the
// destination of the copy operation is the link target and not the
// link itself. This is notably different than CopyInfoSourcePath which
// only evaluates symlinks before the last appearing path separator.
// Also note that it is okay if the last path element is a broken
// symlink as the copy operation should create the target.
var linkTarget string
linkTarget, err = os.Readlink(path)
if err != nil {
return CopyInfo{}, err
}
if !system.IsAbs(linkTarget) {
// Join with the parent directory.
dstParent, _ := SplitPathDirEntry(path)
linkTarget = filepath.Join(dstParent, linkTarget)
}
path = linkTarget
stat, err = os.Lstat(path)
}
if err != nil {
// It's okay if the destination path doesn't exist. We can still
// continue the copy operation if the parent directory exists.
if !os.IsNotExist(err) {
return CopyInfo{}, err
}
// Ensure destination parent dir exists.
dstParent, _ := SplitPathDirEntry(path)
parentDirStat, err := os.Lstat(dstParent)
if err != nil {
return CopyInfo{}, err
}
if !parentDirStat.IsDir() {
return CopyInfo{}, ErrNotDirectory
}
return CopyInfo{Path: path}, nil
}
// The path exists after resolving symlinks.
return CopyInfo{
Path: path,
Exists: true,
IsDir: stat.IsDir(),
}, nil
}
// PrepareArchiveCopy prepares the given srcContent archive, which should
// contain the archived resource described by srcInfo, to the destination
// described by dstInfo. Returns the possibly modified content archive along
// with the path to the destination directory which it should be extracted to.
func PrepareArchiveCopy(srcContent Reader, srcInfo, dstInfo CopyInfo) (dstDir string, content Archive, err error) {
// Ensure in platform semantics
srcInfo.Path = normalizePath(srcInfo.Path)
dstInfo.Path = normalizePath(dstInfo.Path)
// Separate the destination path between its directory and base
// components in case the source archive contents need to be rebased.
dstDir, dstBase := SplitPathDirEntry(dstInfo.Path)
_, srcBase := SplitPathDirEntry(srcInfo.Path)
switch {
case dstInfo.Exists && dstInfo.IsDir:
// The destination exists as a directory. No alteration
// to srcContent is needed as its contents can be
// simply extracted to the destination directory.
return dstInfo.Path, ioutil.NopCloser(srcContent), nil
case dstInfo.Exists && srcInfo.IsDir:
// The destination exists as some type of file and the source
// content is a directory. This is an error condition since
// you cannot copy a directory to an existing file location.
return "", nil, ErrCannotCopyDir
case dstInfo.Exists:
// The destination exists as some type of file and the source content
// is also a file. The source content entry will have to be renamed to
// have a basename which matches the destination path's basename.
if len(srcInfo.RebaseName) != 0 {
srcBase = srcInfo.RebaseName
}
return dstDir, RebaseArchiveEntries(srcContent, srcBase, dstBase), nil
case srcInfo.IsDir:
// The destination does not exist and the source content is an archive
// of a directory. The archive should be extracted to the parent of
// the destination path instead, and when it is, the directory that is
// created as a result should take the name of the destination path.
// The source content entries will have to be renamed to have a
// basename which matches the destination path's basename.
if len(srcInfo.RebaseName) != 0 {
srcBase = srcInfo.RebaseName
}
return dstDir, RebaseArchiveEntries(srcContent, srcBase, dstBase), nil
case assertsDirectory(dstInfo.Path):
// The destination does not exist and is asserted to be created as a
// directory, but the source content is not a directory. This is an
// error condition since you cannot create a directory from a file
// source.
return "", nil, ErrDirNotExists
default:
// The last remaining case is when the destination does not exist, is
// not asserted to be a directory, and the source content is not an
// archive of a directory. It this case, the destination file will need
// to be created when the archive is extracted and the source content
// entry will have to be renamed to have a basename which matches the
// destination path's basename.
if len(srcInfo.RebaseName) != 0 {
srcBase = srcInfo.RebaseName
}
return dstDir, RebaseArchiveEntries(srcContent, srcBase, dstBase), nil
}
}
// RebaseArchiveEntries rewrites the given srcContent archive replacing
// an occurrence of oldBase with newBase at the beginning of entry names.
func RebaseArchiveEntries(srcContent Reader, oldBase, newBase string) Archive {
if oldBase == string(os.PathSeparator) {
// If oldBase specifies the root directory, use an empty string as
// oldBase instead so that newBase doesn't replace the path separator
// that all paths will start with.
oldBase = ""
}
rebased, w := io.Pipe()
go func() {
srcTar := tar.NewReader(srcContent)
rebasedTar := tar.NewWriter(w)
for {
hdr, err := srcTar.Next()
if err == io.EOF {
// Signals end of archive.
rebasedTar.Close()
w.Close()
return
}
if err != nil {
w.CloseWithError(err)
return
}
hdr.Name = strings.Replace(hdr.Name, oldBase, newBase, 1)
if err = rebasedTar.WriteHeader(hdr); err != nil {
w.CloseWithError(err)
return
}
if _, err = io.Copy(rebasedTar, srcTar); err != nil {
w.CloseWithError(err)
return
}
}
}()
return rebased
}
// CopyResource performs an archive copy from the given source path to the
// given destination path. The source path MUST exist and the destination
// path's parent directory must exist.
func CopyResource(srcPath, dstPath string, followLink bool) error {
var (
srcInfo CopyInfo
err error
)
// Ensure in platform semantics
srcPath = normalizePath(srcPath)
dstPath = normalizePath(dstPath)
// Clean the source and destination paths.
srcPath = PreserveTrailingDotOrSeparator(filepath.Clean(srcPath), srcPath)
dstPath = PreserveTrailingDotOrSeparator(filepath.Clean(dstPath), dstPath)
if srcInfo, err = CopyInfoSourcePath(srcPath, followLink); err != nil {
return err
}
content, err := TarResource(srcInfo)
if err != nil {
return err
}
defer content.Close()
return CopyTo(content, srcInfo, dstPath)
}
// CopyTo handles extracting the given content whose
// entries should be sourced from srcInfo to dstPath.
func CopyTo(content Reader, srcInfo CopyInfo, dstPath string) error {
// The destination path need not exist, but CopyInfoDestinationPath will
// ensure that at least the parent directory exists.
dstInfo, err := CopyInfoDestinationPath(normalizePath(dstPath))
if err != nil {
return err
}
dstDir, copyArchive, err := PrepareArchiveCopy(content, srcInfo, dstInfo)
if err != nil {
return err
}
defer copyArchive.Close()
options := &TarOptions{
NoLchown: true,
NoOverwriteDirNonDir: true,
}
return Untar(copyArchive, dstDir, options)
}
// ResolveHostSourcePath decides real path need to be copied with parameters such as
// whether to follow symbol link or not, if followLink is true, resolvedPath will return
// link target of any symbol link file, else it will only resolve symlink of directory
// but return symbol link file itself without resolving.
func ResolveHostSourcePath(path string, followLink bool) (resolvedPath, rebaseName string, err error) {
if followLink {
resolvedPath, err = filepath.EvalSymlinks(path)
if err != nil {
return
}
resolvedPath, rebaseName = GetRebaseName(path, resolvedPath)
} else {
dirPath, basePath := filepath.Split(path)
// if not follow symbol link, then resolve symbol link of parent dir
var resolvedDirPath string
resolvedDirPath, err = filepath.EvalSymlinks(dirPath)
if err != nil {
return
}
// resolvedDirPath will have been cleaned (no trailing path separators) so
// we can manually join it with the base path element.
resolvedPath = resolvedDirPath + string(filepath.Separator) + basePath
if hasTrailingPathSeparator(path) && filepath.Base(path) != filepath.Base(resolvedPath) {
rebaseName = filepath.Base(path)
}
}
return resolvedPath, rebaseName, nil
}
// GetRebaseName normalizes and compares path and resolvedPath,
// return completed resolved path and rebased file name
func GetRebaseName(path, resolvedPath string) (string, string) {
// linkTarget will have been cleaned (no trailing path separators and dot) so
// we can manually join it with them
var rebaseName string
if specifiesCurrentDir(path) && !specifiesCurrentDir(resolvedPath) {
resolvedPath += string(filepath.Separator) + "."
}
if hasTrailingPathSeparator(path) && !hasTrailingPathSeparator(resolvedPath) {
resolvedPath += string(filepath.Separator)
}
if filepath.Base(path) != filepath.Base(resolvedPath) {
// In the case where the path had a trailing separator and a symlink
// evaluation has changed the last path component, we will need to
// rebase the name in the archive that is being copied to match the
// originally requested name.
rebaseName = filepath.Base(path)
}
return resolvedPath, rebaseName
}

View File

@@ -0,0 +1,974 @@
package archive
import (
"bytes"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"strings"
"testing"
)
func removeAllPaths(paths ...string) {
for _, path := range paths {
os.RemoveAll(path)
}
}
func getTestTempDirs(t *testing.T) (tmpDirA, tmpDirB string) {
var err error
if tmpDirA, err = ioutil.TempDir("", "archive-copy-test"); err != nil {
t.Fatal(err)
}
if tmpDirB, err = ioutil.TempDir("", "archive-copy-test"); err != nil {
t.Fatal(err)
}
return
}
func isNotDir(err error) bool {
return strings.Contains(err.Error(), "not a directory")
}
func joinTrailingSep(pathElements ...string) string {
joined := filepath.Join(pathElements...)
return fmt.Sprintf("%s%c", joined, filepath.Separator)
}
func fileContentsEqual(t *testing.T, filenameA, filenameB string) (err error) {
t.Logf("checking for equal file contents: %q and %q\n", filenameA, filenameB)
fileA, err := os.Open(filenameA)
if err != nil {
return
}
defer fileA.Close()
fileB, err := os.Open(filenameB)
if err != nil {
return
}
defer fileB.Close()
hasher := sha256.New()
if _, err = io.Copy(hasher, fileA); err != nil {
return
}
hashA := hasher.Sum(nil)
hasher.Reset()
if _, err = io.Copy(hasher, fileB); err != nil {
return
}
hashB := hasher.Sum(nil)
if !bytes.Equal(hashA, hashB) {
err = fmt.Errorf("file content hashes not equal - expected %s, got %s", hex.EncodeToString(hashA), hex.EncodeToString(hashB))
}
return
}
func dirContentsEqual(t *testing.T, newDir, oldDir string) (err error) {
t.Logf("checking for equal directory contents: %q and %q\n", newDir, oldDir)
var changes []Change
if changes, err = ChangesDirs(newDir, oldDir); err != nil {
return
}
if len(changes) != 0 {
err = fmt.Errorf("expected no changes between directories, but got: %v", changes)
}
return
}
func logDirContents(t *testing.T, dirPath string) {
logWalkedPaths := filepath.WalkFunc(func(path string, info os.FileInfo, err error) error {
if err != nil {
t.Errorf("stat error for path %q: %s", path, err)
return nil
}
if info.IsDir() {
path = joinTrailingSep(path)
}
t.Logf("\t%s", path)
return nil
})
t.Logf("logging directory contents: %q", dirPath)
if err := filepath.Walk(dirPath, logWalkedPaths); err != nil {
t.Fatal(err)
}
}
func testCopyHelper(t *testing.T, srcPath, dstPath string) (err error) {
t.Logf("copying from %q to %q (not follow symbol link)", srcPath, dstPath)
return CopyResource(srcPath, dstPath, false)
}
func testCopyHelperFSym(t *testing.T, srcPath, dstPath string) (err error) {
t.Logf("copying from %q to %q (follow symbol link)", srcPath, dstPath)
return CopyResource(srcPath, dstPath, true)
}
// Basic assumptions about SRC and DST:
// 1. SRC must exist.
// 2. If SRC ends with a trailing separator, it must be a directory.
// 3. DST parent directory must exist.
// 4. If DST exists as a file, it must not end with a trailing separator.
// First get these easy error cases out of the way.
// Test for error when SRC does not exist.
func TestCopyErrSrcNotExists(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
if _, err := CopyInfoSourcePath(filepath.Join(tmpDirA, "file1"), false); !os.IsNotExist(err) {
t.Fatalf("expected IsNotExist error, but got %T: %s", err, err)
}
}
// Test for error when SRC ends in a trailing
// path separator but it exists as a file.
func TestCopyErrSrcNotDir(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A with some sample files and directories.
createSampleDir(t, tmpDirA)
if _, err := CopyInfoSourcePath(joinTrailingSep(tmpDirA, "file1"), false); !isNotDir(err) {
t.Fatalf("expected IsNotDir error, but got %T: %s", err, err)
}
}
// Test for error when SRC is a valid file or directory,
// but the DST parent directory does not exist.
func TestCopyErrDstParentNotExists(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A with some sample files and directories.
createSampleDir(t, tmpDirA)
srcInfo := CopyInfo{Path: filepath.Join(tmpDirA, "file1"), Exists: true, IsDir: false}
// Try with a file source.
content, err := TarResource(srcInfo)
if err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
defer content.Close()
// Copy to a file whose parent does not exist.
if err = CopyTo(content, srcInfo, filepath.Join(tmpDirB, "fakeParentDir", "file1")); err == nil {
t.Fatal("expected IsNotExist error, but got nil instead")
}
if !os.IsNotExist(err) {
t.Fatalf("expected IsNotExist error, but got %T: %s", err, err)
}
// Try with a directory source.
srcInfo = CopyInfo{Path: filepath.Join(tmpDirA, "dir1"), Exists: true, IsDir: true}
content, err = TarResource(srcInfo)
if err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
defer content.Close()
// Copy to a directory whose parent does not exist.
if err = CopyTo(content, srcInfo, joinTrailingSep(tmpDirB, "fakeParentDir", "fakeDstDir")); err == nil {
t.Fatal("expected IsNotExist error, but got nil instead")
}
if !os.IsNotExist(err) {
t.Fatalf("expected IsNotExist error, but got %T: %s", err, err)
}
}
// Test for error when DST ends in a trailing
// path separator but exists as a file.
func TestCopyErrDstNotDir(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A and B with some sample files and directories.
createSampleDir(t, tmpDirA)
createSampleDir(t, tmpDirB)
// Try with a file source.
srcInfo := CopyInfo{Path: filepath.Join(tmpDirA, "file1"), Exists: true, IsDir: false}
content, err := TarResource(srcInfo)
if err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
defer content.Close()
if err = CopyTo(content, srcInfo, joinTrailingSep(tmpDirB, "file1")); err == nil {
t.Fatal("expected IsNotDir error, but got nil instead")
}
if !isNotDir(err) {
t.Fatalf("expected IsNotDir error, but got %T: %s", err, err)
}
// Try with a directory source.
srcInfo = CopyInfo{Path: filepath.Join(tmpDirA, "dir1"), Exists: true, IsDir: true}
content, err = TarResource(srcInfo)
if err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
defer content.Close()
if err = CopyTo(content, srcInfo, joinTrailingSep(tmpDirB, "file1")); err == nil {
t.Fatal("expected IsNotDir error, but got nil instead")
}
if !isNotDir(err) {
t.Fatalf("expected IsNotDir error, but got %T: %s", err, err)
}
}
// Possibilities are reduced to the remaining 10 cases:
//
// case | srcIsDir | onlyDirContents | dstExists | dstIsDir | dstTrSep | action
// ===================================================================================================
// A | no | - | no | - | no | create file
// B | no | - | no | - | yes | error
// C | no | - | yes | no | - | overwrite file
// D | no | - | yes | yes | - | create file in dst dir
// E | yes | no | no | - | - | create dir, copy contents
// F | yes | no | yes | no | - | error
// G | yes | no | yes | yes | - | copy dir and contents
// H | yes | yes | no | - | - | create dir, copy contents
// I | yes | yes | yes | no | - | error
// J | yes | yes | yes | yes | - | copy dir contents
//
// A. SRC specifies a file and DST (no trailing path separator) doesn't
// exist. This should create a file with the name DST and copy the
// contents of the source file into it.
func TestCopyCaseA(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A with some sample files and directories.
createSampleDir(t, tmpDirA)
srcPath := filepath.Join(tmpDirA, "file1")
dstPath := filepath.Join(tmpDirB, "itWorks.txt")
var err error
if err = testCopyHelper(t, srcPath, dstPath); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = fileContentsEqual(t, srcPath, dstPath); err != nil {
t.Fatal(err)
}
os.Remove(dstPath)
symlinkPath := filepath.Join(tmpDirA, "symlink3")
symlinkPath1 := filepath.Join(tmpDirA, "symlink4")
linkTarget := filepath.Join(tmpDirA, "file1")
if err = testCopyHelperFSym(t, symlinkPath, dstPath); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = fileContentsEqual(t, linkTarget, dstPath); err != nil {
t.Fatal(err)
}
os.Remove(dstPath)
if err = testCopyHelperFSym(t, symlinkPath1, dstPath); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = fileContentsEqual(t, linkTarget, dstPath); err != nil {
t.Fatal(err)
}
}
// B. SRC specifies a file and DST (with trailing path separator) doesn't
// exist. This should cause an error because the copy operation cannot
// create a directory when copying a single file.
func TestCopyCaseB(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A with some sample files and directories.
createSampleDir(t, tmpDirA)
srcPath := filepath.Join(tmpDirA, "file1")
dstDir := joinTrailingSep(tmpDirB, "testDir")
var err error
if err = testCopyHelper(t, srcPath, dstDir); err == nil {
t.Fatal("expected ErrDirNotExists error, but got nil instead")
}
if err != ErrDirNotExists {
t.Fatalf("expected ErrDirNotExists error, but got %T: %s", err, err)
}
symlinkPath := filepath.Join(tmpDirA, "symlink3")
if err = testCopyHelperFSym(t, symlinkPath, dstDir); err == nil {
t.Fatal("expected ErrDirNotExists error, but got nil instead")
}
if err != ErrDirNotExists {
t.Fatalf("expected ErrDirNotExists error, but got %T: %s", err, err)
}
}
// C. SRC specifies a file and DST exists as a file. This should overwrite
// the file at DST with the contents of the source file.
func TestCopyCaseC(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A and B with some sample files and directories.
createSampleDir(t, tmpDirA)
createSampleDir(t, tmpDirB)
srcPath := filepath.Join(tmpDirA, "file1")
dstPath := filepath.Join(tmpDirB, "file2")
var err error
// Ensure they start out different.
if err = fileContentsEqual(t, srcPath, dstPath); err == nil {
t.Fatal("expected different file contents")
}
if err = testCopyHelper(t, srcPath, dstPath); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = fileContentsEqual(t, srcPath, dstPath); err != nil {
t.Fatal(err)
}
}
// C. Symbol link following version:
// SRC specifies a file and DST exists as a file. This should overwrite
// the file at DST with the contents of the source file.
func TestCopyCaseCFSym(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A and B with some sample files and directories.
createSampleDir(t, tmpDirA)
createSampleDir(t, tmpDirB)
symlinkPathBad := filepath.Join(tmpDirA, "symlink1")
symlinkPath := filepath.Join(tmpDirA, "symlink3")
linkTarget := filepath.Join(tmpDirA, "file1")
dstPath := filepath.Join(tmpDirB, "file2")
var err error
// first to test broken link
if err = testCopyHelperFSym(t, symlinkPathBad, dstPath); err == nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
// test symbol link -> symbol link -> target
// Ensure they start out different.
if err = fileContentsEqual(t, linkTarget, dstPath); err == nil {
t.Fatal("expected different file contents")
}
if err = testCopyHelperFSym(t, symlinkPath, dstPath); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = fileContentsEqual(t, linkTarget, dstPath); err != nil {
t.Fatal(err)
}
}
// D. SRC specifies a file and DST exists as a directory. This should place
// a copy of the source file inside it using the basename from SRC. Ensure
// this works whether DST has a trailing path separator or not.
func TestCopyCaseD(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A and B with some sample files and directories.
createSampleDir(t, tmpDirA)
createSampleDir(t, tmpDirB)
srcPath := filepath.Join(tmpDirA, "file1")
dstDir := filepath.Join(tmpDirB, "dir1")
dstPath := filepath.Join(dstDir, "file1")
var err error
// Ensure that dstPath doesn't exist.
if _, err = os.Stat(dstPath); !os.IsNotExist(err) {
t.Fatalf("did not expect dstPath %q to exist", dstPath)
}
if err = testCopyHelper(t, srcPath, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = fileContentsEqual(t, srcPath, dstPath); err != nil {
t.Fatal(err)
}
// Now try again but using a trailing path separator for dstDir.
if err = os.RemoveAll(dstDir); err != nil {
t.Fatalf("unable to remove dstDir: %s", err)
}
if err = os.MkdirAll(dstDir, os.FileMode(0755)); err != nil {
t.Fatalf("unable to make dstDir: %s", err)
}
dstDir = joinTrailingSep(tmpDirB, "dir1")
if err = testCopyHelper(t, srcPath, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = fileContentsEqual(t, srcPath, dstPath); err != nil {
t.Fatal(err)
}
}
// D. Symbol link following version:
// SRC specifies a file and DST exists as a directory. This should place
// a copy of the source file inside it using the basename from SRC. Ensure
// this works whether DST has a trailing path separator or not.
func TestCopyCaseDFSym(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A and B with some sample files and directories.
createSampleDir(t, tmpDirA)
createSampleDir(t, tmpDirB)
srcPath := filepath.Join(tmpDirA, "symlink4")
linkTarget := filepath.Join(tmpDirA, "file1")
dstDir := filepath.Join(tmpDirB, "dir1")
dstPath := filepath.Join(dstDir, "symlink4")
var err error
// Ensure that dstPath doesn't exist.
if _, err = os.Stat(dstPath); !os.IsNotExist(err) {
t.Fatalf("did not expect dstPath %q to exist", dstPath)
}
if err = testCopyHelperFSym(t, srcPath, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = fileContentsEqual(t, linkTarget, dstPath); err != nil {
t.Fatal(err)
}
// Now try again but using a trailing path separator for dstDir.
if err = os.RemoveAll(dstDir); err != nil {
t.Fatalf("unable to remove dstDir: %s", err)
}
if err = os.MkdirAll(dstDir, os.FileMode(0755)); err != nil {
t.Fatalf("unable to make dstDir: %s", err)
}
dstDir = joinTrailingSep(tmpDirB, "dir1")
if err = testCopyHelperFSym(t, srcPath, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = fileContentsEqual(t, linkTarget, dstPath); err != nil {
t.Fatal(err)
}
}
// E. SRC specifies a directory and DST does not exist. This should create a
// directory at DST and copy the contents of the SRC directory into the DST
// directory. Ensure this works whether DST has a trailing path separator or
// not.
func TestCopyCaseE(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A with some sample files and directories.
createSampleDir(t, tmpDirA)
srcDir := filepath.Join(tmpDirA, "dir1")
dstDir := filepath.Join(tmpDirB, "testDir")
var err error
if err = testCopyHelper(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, dstDir, srcDir); err != nil {
t.Log("dir contents not equal")
logDirContents(t, tmpDirA)
logDirContents(t, tmpDirB)
t.Fatal(err)
}
// Now try again but using a trailing path separator for dstDir.
if err = os.RemoveAll(dstDir); err != nil {
t.Fatalf("unable to remove dstDir: %s", err)
}
dstDir = joinTrailingSep(tmpDirB, "testDir")
if err = testCopyHelper(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, dstDir, srcDir); err != nil {
t.Fatal(err)
}
}
// E. Symbol link following version:
// SRC specifies a directory and DST does not exist. This should create a
// directory at DST and copy the contents of the SRC directory into the DST
// directory. Ensure this works whether DST has a trailing path separator or
// not.
func TestCopyCaseEFSym(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A with some sample files and directories.
createSampleDir(t, tmpDirA)
srcDir := filepath.Join(tmpDirA, "dirSymlink")
linkTarget := filepath.Join(tmpDirA, "dir1")
dstDir := filepath.Join(tmpDirB, "testDir")
var err error
if err = testCopyHelperFSym(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, dstDir, linkTarget); err != nil {
t.Log("dir contents not equal")
logDirContents(t, tmpDirA)
logDirContents(t, tmpDirB)
t.Fatal(err)
}
// Now try again but using a trailing path separator for dstDir.
if err = os.RemoveAll(dstDir); err != nil {
t.Fatalf("unable to remove dstDir: %s", err)
}
dstDir = joinTrailingSep(tmpDirB, "testDir")
if err = testCopyHelperFSym(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, dstDir, linkTarget); err != nil {
t.Fatal(err)
}
}
// F. SRC specifies a directory and DST exists as a file. This should cause an
// error as it is not possible to overwrite a file with a directory.
func TestCopyCaseF(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A and B with some sample files and directories.
createSampleDir(t, tmpDirA)
createSampleDir(t, tmpDirB)
srcDir := filepath.Join(tmpDirA, "dir1")
symSrcDir := filepath.Join(tmpDirA, "dirSymlink")
dstFile := filepath.Join(tmpDirB, "file1")
var err error
if err = testCopyHelper(t, srcDir, dstFile); err == nil {
t.Fatal("expected ErrCannotCopyDir error, but got nil instead")
}
if err != ErrCannotCopyDir {
t.Fatalf("expected ErrCannotCopyDir error, but got %T: %s", err, err)
}
// now test with symbol link
if err = testCopyHelperFSym(t, symSrcDir, dstFile); err == nil {
t.Fatal("expected ErrCannotCopyDir error, but got nil instead")
}
if err != ErrCannotCopyDir {
t.Fatalf("expected ErrCannotCopyDir error, but got %T: %s", err, err)
}
}
// G. SRC specifies a directory and DST exists as a directory. This should copy
// the SRC directory and all its contents to the DST directory. Ensure this
// works whether DST has a trailing path separator or not.
func TestCopyCaseG(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A and B with some sample files and directories.
createSampleDir(t, tmpDirA)
createSampleDir(t, tmpDirB)
srcDir := filepath.Join(tmpDirA, "dir1")
dstDir := filepath.Join(tmpDirB, "dir2")
resultDir := filepath.Join(dstDir, "dir1")
var err error
if err = testCopyHelper(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, resultDir, srcDir); err != nil {
t.Fatal(err)
}
// Now try again but using a trailing path separator for dstDir.
if err = os.RemoveAll(dstDir); err != nil {
t.Fatalf("unable to remove dstDir: %s", err)
}
if err = os.MkdirAll(dstDir, os.FileMode(0755)); err != nil {
t.Fatalf("unable to make dstDir: %s", err)
}
dstDir = joinTrailingSep(tmpDirB, "dir2")
if err = testCopyHelper(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, resultDir, srcDir); err != nil {
t.Fatal(err)
}
}
// G. Symbol link version:
// SRC specifies a directory and DST exists as a directory. This should copy
// the SRC directory and all its contents to the DST directory. Ensure this
// works whether DST has a trailing path separator or not.
func TestCopyCaseGFSym(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A and B with some sample files and directories.
createSampleDir(t, tmpDirA)
createSampleDir(t, tmpDirB)
srcDir := filepath.Join(tmpDirA, "dirSymlink")
linkTarget := filepath.Join(tmpDirA, "dir1")
dstDir := filepath.Join(tmpDirB, "dir2")
resultDir := filepath.Join(dstDir, "dirSymlink")
var err error
if err = testCopyHelperFSym(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, resultDir, linkTarget); err != nil {
t.Fatal(err)
}
// Now try again but using a trailing path separator for dstDir.
if err = os.RemoveAll(dstDir); err != nil {
t.Fatalf("unable to remove dstDir: %s", err)
}
if err = os.MkdirAll(dstDir, os.FileMode(0755)); err != nil {
t.Fatalf("unable to make dstDir: %s", err)
}
dstDir = joinTrailingSep(tmpDirB, "dir2")
if err = testCopyHelperFSym(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, resultDir, linkTarget); err != nil {
t.Fatal(err)
}
}
// H. SRC specifies a directory's contents only and DST does not exist. This
// should create a directory at DST and copy the contents of the SRC
// directory (but not the directory itself) into the DST directory. Ensure
// this works whether DST has a trailing path separator or not.
func TestCopyCaseH(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A with some sample files and directories.
createSampleDir(t, tmpDirA)
srcDir := joinTrailingSep(tmpDirA, "dir1") + "."
dstDir := filepath.Join(tmpDirB, "testDir")
var err error
if err = testCopyHelper(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, dstDir, srcDir); err != nil {
t.Log("dir contents not equal")
logDirContents(t, tmpDirA)
logDirContents(t, tmpDirB)
t.Fatal(err)
}
// Now try again but using a trailing path separator for dstDir.
if err = os.RemoveAll(dstDir); err != nil {
t.Fatalf("unable to remove dstDir: %s", err)
}
dstDir = joinTrailingSep(tmpDirB, "testDir")
if err = testCopyHelper(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, dstDir, srcDir); err != nil {
t.Log("dir contents not equal")
logDirContents(t, tmpDirA)
logDirContents(t, tmpDirB)
t.Fatal(err)
}
}
// H. Symbol link following version:
// SRC specifies a directory's contents only and DST does not exist. This
// should create a directory at DST and copy the contents of the SRC
// directory (but not the directory itself) into the DST directory. Ensure
// this works whether DST has a trailing path separator or not.
func TestCopyCaseHFSym(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A with some sample files and directories.
createSampleDir(t, tmpDirA)
srcDir := joinTrailingSep(tmpDirA, "dirSymlink") + "."
linkTarget := filepath.Join(tmpDirA, "dir1")
dstDir := filepath.Join(tmpDirB, "testDir")
var err error
if err = testCopyHelperFSym(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, dstDir, linkTarget); err != nil {
t.Log("dir contents not equal")
logDirContents(t, tmpDirA)
logDirContents(t, tmpDirB)
t.Fatal(err)
}
// Now try again but using a trailing path separator for dstDir.
if err = os.RemoveAll(dstDir); err != nil {
t.Fatalf("unable to remove dstDir: %s", err)
}
dstDir = joinTrailingSep(tmpDirB, "testDir")
if err = testCopyHelperFSym(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, dstDir, linkTarget); err != nil {
t.Log("dir contents not equal")
logDirContents(t, tmpDirA)
logDirContents(t, tmpDirB)
t.Fatal(err)
}
}
// I. SRC specifies a directory's contents only and DST exists as a file. This
// should cause an error as it is not possible to overwrite a file with a
// directory.
func TestCopyCaseI(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A and B with some sample files and directories.
createSampleDir(t, tmpDirA)
createSampleDir(t, tmpDirB)
srcDir := joinTrailingSep(tmpDirA, "dir1") + "."
symSrcDir := filepath.Join(tmpDirB, "dirSymlink")
dstFile := filepath.Join(tmpDirB, "file1")
var err error
if err = testCopyHelper(t, srcDir, dstFile); err == nil {
t.Fatal("expected ErrCannotCopyDir error, but got nil instead")
}
if err != ErrCannotCopyDir {
t.Fatalf("expected ErrCannotCopyDir error, but got %T: %s", err, err)
}
// now try with symbol link of dir
if err = testCopyHelperFSym(t, symSrcDir, dstFile); err == nil {
t.Fatal("expected ErrCannotCopyDir error, but got nil instead")
}
if err != ErrCannotCopyDir {
t.Fatalf("expected ErrCannotCopyDir error, but got %T: %s", err, err)
}
}
// J. SRC specifies a directory's contents only and DST exists as a directory.
// This should copy the contents of the SRC directory (but not the directory
// itself) into the DST directory. Ensure this works whether DST has a
// trailing path separator or not.
func TestCopyCaseJ(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A and B with some sample files and directories.
createSampleDir(t, tmpDirA)
createSampleDir(t, tmpDirB)
srcDir := joinTrailingSep(tmpDirA, "dir1") + "."
dstDir := filepath.Join(tmpDirB, "dir5")
var err error
// first to create an empty dir
if err = os.MkdirAll(dstDir, os.FileMode(0755)); err != nil {
t.Fatalf("unable to make dstDir: %s", err)
}
if err = testCopyHelper(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, dstDir, srcDir); err != nil {
t.Fatal(err)
}
// Now try again but using a trailing path separator for dstDir.
if err = os.RemoveAll(dstDir); err != nil {
t.Fatalf("unable to remove dstDir: %s", err)
}
if err = os.MkdirAll(dstDir, os.FileMode(0755)); err != nil {
t.Fatalf("unable to make dstDir: %s", err)
}
dstDir = joinTrailingSep(tmpDirB, "dir5")
if err = testCopyHelper(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, dstDir, srcDir); err != nil {
t.Fatal(err)
}
}
// J. Symbol link following version:
// SRC specifies a directory's contents only and DST exists as a directory.
// This should copy the contents of the SRC directory (but not the directory
// itself) into the DST directory. Ensure this works whether DST has a
// trailing path separator or not.
func TestCopyCaseJFSym(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A and B with some sample files and directories.
createSampleDir(t, tmpDirA)
createSampleDir(t, tmpDirB)
srcDir := joinTrailingSep(tmpDirA, "dirSymlink") + "."
linkTarget := filepath.Join(tmpDirA, "dir1")
dstDir := filepath.Join(tmpDirB, "dir5")
var err error
// first to create an empty dir
if err = os.MkdirAll(dstDir, os.FileMode(0755)); err != nil {
t.Fatalf("unable to make dstDir: %s", err)
}
if err = testCopyHelperFSym(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, dstDir, linkTarget); err != nil {
t.Fatal(err)
}
// Now try again but using a trailing path separator for dstDir.
if err = os.RemoveAll(dstDir); err != nil {
t.Fatalf("unable to remove dstDir: %s", err)
}
if err = os.MkdirAll(dstDir, os.FileMode(0755)); err != nil {
t.Fatalf("unable to make dstDir: %s", err)
}
dstDir = joinTrailingSep(tmpDirB, "dir5")
if err = testCopyHelperFSym(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, dstDir, linkTarget); err != nil {
t.Fatal(err)
}
}

View File

@@ -0,0 +1,11 @@
// +build !windows
package archive
import (
"path/filepath"
)
func normalizePath(path string) string {
return filepath.ToSlash(path)
}

View File

@@ -0,0 +1,9 @@
package archive
import (
"path/filepath"
)
func normalizePath(path string) string {
return filepath.FromSlash(path)
}

279
vendor/github.com/hyperhq/hypercli/pkg/archive/diff.go generated vendored Normal file
View File

@@ -0,0 +1,279 @@
package archive
import (
"archive/tar"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"runtime"
"strings"
"github.com/Sirupsen/logrus"
"github.com/hyperhq/hypercli/pkg/idtools"
"github.com/hyperhq/hypercli/pkg/pools"
"github.com/hyperhq/hypercli/pkg/system"
)
// UnpackLayer unpack `layer` to a `dest`. The stream `layer` can be
// compressed or uncompressed.
// Returns the size in bytes of the contents of the layer.
func UnpackLayer(dest string, layer Reader, options *TarOptions) (size int64, err error) {
tr := tar.NewReader(layer)
trBuf := pools.BufioReader32KPool.Get(tr)
defer pools.BufioReader32KPool.Put(trBuf)
var dirs []*tar.Header
unpackedPaths := make(map[string]struct{})
if options == nil {
options = &TarOptions{}
}
if options.ExcludePatterns == nil {
options.ExcludePatterns = []string{}
}
remappedRootUID, remappedRootGID, err := idtools.GetRootUIDGID(options.UIDMaps, options.GIDMaps)
if err != nil {
return 0, err
}
aufsTempdir := ""
aufsHardlinks := make(map[string]*tar.Header)
if options == nil {
options = &TarOptions{}
}
// Iterate through the files in the archive.
for {
hdr, err := tr.Next()
if err == io.EOF {
// end of tar archive
break
}
if err != nil {
return 0, err
}
size += hdr.Size
// Normalize name, for safety and for a simple is-root check
hdr.Name = filepath.Clean(hdr.Name)
// Windows does not support filenames with colons in them. Ignore
// these files. This is not a problem though (although it might
// appear that it is). Let's suppose a client is running docker pull.
// The daemon it points to is Windows. Would it make sense for the
// client to be doing a docker pull Ubuntu for example (which has files
// with colons in the name under /usr/share/man/man3)? No, absolutely
// not as it would really only make sense that they were pulling a
// Windows image. However, for development, it is necessary to be able
// to pull Linux images which are in the repository.
//
// TODO Windows. Once the registry is aware of what images are Windows-
// specific or Linux-specific, this warning should be changed to an error
// to cater for the situation where someone does manage to upload a Linux
// image but have it tagged as Windows inadvertently.
if runtime.GOOS == "windows" {
if strings.Contains(hdr.Name, ":") {
logrus.Warnf("Windows: Ignoring %s (is this a Linux image?)", hdr.Name)
continue
}
}
// Note as these operations are platform specific, so must the slash be.
if !strings.HasSuffix(hdr.Name, string(os.PathSeparator)) {
// Not the root directory, ensure that the parent directory exists.
// This happened in some tests where an image had a tarfile without any
// parent directories.
parent := filepath.Dir(hdr.Name)
parentPath := filepath.Join(dest, parent)
if _, err := os.Lstat(parentPath); err != nil && os.IsNotExist(err) {
err = system.MkdirAll(parentPath, 0600)
if err != nil {
return 0, err
}
}
}
// Skip AUFS metadata dirs
if strings.HasPrefix(hdr.Name, WhiteoutMetaPrefix) {
// Regular files inside /.wh..wh.plnk can be used as hardlink targets
// We don't want this directory, but we need the files in them so that
// such hardlinks can be resolved.
if strings.HasPrefix(hdr.Name, WhiteoutLinkDir) && hdr.Typeflag == tar.TypeReg {
basename := filepath.Base(hdr.Name)
aufsHardlinks[basename] = hdr
if aufsTempdir == "" {
if aufsTempdir, err = ioutil.TempDir("", "dockerplnk"); err != nil {
return 0, err
}
defer os.RemoveAll(aufsTempdir)
}
if err := createTarFile(filepath.Join(aufsTempdir, basename), dest, hdr, tr, true, nil); err != nil {
return 0, err
}
}
if hdr.Name != WhiteoutOpaqueDir {
continue
}
}
path := filepath.Join(dest, hdr.Name)
rel, err := filepath.Rel(dest, path)
if err != nil {
return 0, err
}
// Note as these operations are platform specific, so must the slash be.
if strings.HasPrefix(rel, ".."+string(os.PathSeparator)) {
return 0, breakoutError(fmt.Errorf("%q is outside of %q", hdr.Name, dest))
}
base := filepath.Base(path)
if strings.HasPrefix(base, WhiteoutPrefix) {
dir := filepath.Dir(path)
if base == WhiteoutOpaqueDir {
_, err := os.Lstat(dir)
if err != nil {
return 0, err
}
err = filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
if err != nil {
if os.IsNotExist(err) {
err = nil // parent was deleted
}
return err
}
if path == dir {
return nil
}
if _, exists := unpackedPaths[path]; !exists {
err := os.RemoveAll(path)
return err
}
return nil
})
if err != nil {
return 0, err
}
} else {
originalBase := base[len(WhiteoutPrefix):]
originalPath := filepath.Join(dir, originalBase)
if err := os.RemoveAll(originalPath); err != nil {
return 0, err
}
}
} else {
// If path exits we almost always just want to remove and replace it.
// The only exception is when it is a directory *and* the file from
// the layer is also a directory. Then we want to merge them (i.e.
// just apply the metadata from the layer).
if fi, err := os.Lstat(path); err == nil {
if !(fi.IsDir() && hdr.Typeflag == tar.TypeDir) {
if err := os.RemoveAll(path); err != nil {
return 0, err
}
}
}
trBuf.Reset(tr)
srcData := io.Reader(trBuf)
srcHdr := hdr
// Hard links into /.wh..wh.plnk don't work, as we don't extract that directory, so
// we manually retarget these into the temporary files we extracted them into
if hdr.Typeflag == tar.TypeLink && strings.HasPrefix(filepath.Clean(hdr.Linkname), WhiteoutLinkDir) {
linkBasename := filepath.Base(hdr.Linkname)
srcHdr = aufsHardlinks[linkBasename]
if srcHdr == nil {
return 0, fmt.Errorf("Invalid aufs hardlink")
}
tmpFile, err := os.Open(filepath.Join(aufsTempdir, linkBasename))
if err != nil {
return 0, err
}
defer tmpFile.Close()
srcData = tmpFile
}
// if the options contain a uid & gid maps, convert header uid/gid
// entries using the maps such that lchown sets the proper mapped
// uid/gid after writing the file. We only perform this mapping if
// the file isn't already owned by the remapped root UID or GID, as
// that specific uid/gid has no mapping from container -> host, and
// those files already have the proper ownership for inside the
// container.
if srcHdr.Uid != remappedRootUID {
xUID, err := idtools.ToHost(srcHdr.Uid, options.UIDMaps)
if err != nil {
return 0, err
}
srcHdr.Uid = xUID
}
if srcHdr.Gid != remappedRootGID {
xGID, err := idtools.ToHost(srcHdr.Gid, options.GIDMaps)
if err != nil {
return 0, err
}
srcHdr.Gid = xGID
}
if err := createTarFile(path, dest, srcHdr, srcData, true, nil); err != nil {
return 0, err
}
// Directory mtimes must be handled at the end to avoid further
// file creation in them to modify the directory mtime
if hdr.Typeflag == tar.TypeDir {
dirs = append(dirs, hdr)
}
unpackedPaths[path] = struct{}{}
}
}
for _, hdr := range dirs {
path := filepath.Join(dest, hdr.Name)
if err := system.Chtimes(path, hdr.AccessTime, hdr.ModTime); err != nil {
return 0, err
}
}
return size, nil
}
// ApplyLayer parses a diff in the standard layer format from `layer`,
// and applies it to the directory `dest`. The stream `layer` can be
// compressed or uncompressed.
// Returns the size in bytes of the contents of the layer.
func ApplyLayer(dest string, layer Reader) (int64, error) {
return applyLayerHandler(dest, layer, &TarOptions{}, true)
}
// ApplyUncompressedLayer parses a diff in the standard layer format from
// `layer`, and applies it to the directory `dest`. The stream `layer`
// can only be uncompressed.
// Returns the size in bytes of the contents of the layer.
func ApplyUncompressedLayer(dest string, layer Reader, options *TarOptions) (int64, error) {
return applyLayerHandler(dest, layer, options, false)
}
// do the bulk load of ApplyLayer, but allow for not calling DecompressStream
func applyLayerHandler(dest string, layer Reader, options *TarOptions, decompress bool) (int64, error) {
dest = filepath.Clean(dest)
// We need to be able to set any perms
oldmask, err := system.Umask(0)
if err != nil {
return 0, err
}
defer system.Umask(oldmask) // ignore err, ErrNotSupportedPlatform
if decompress {
layer, err = DecompressStream(layer)
if err != nil {
return 0, err
}
}
return UnpackLayer(dest, layer, options)
}

View File

@@ -0,0 +1,370 @@
package archive
import (
"archive/tar"
"io"
"io/ioutil"
"os"
"path/filepath"
"reflect"
"testing"
"github.com/hyperhq/hypercli/pkg/ioutils"
)
func TestApplyLayerInvalidFilenames(t *testing.T) {
for i, headers := range [][]*tar.Header{
{
{
Name: "../victim/dotdot",
Typeflag: tar.TypeReg,
Mode: 0644,
},
},
{
{
// Note the leading slash
Name: "/../victim/slash-dotdot",
Typeflag: tar.TypeReg,
Mode: 0644,
},
},
} {
if err := testBreakout("applylayer", "docker-TestApplyLayerInvalidFilenames", headers); err != nil {
t.Fatalf("i=%d. %v", i, err)
}
}
}
func TestApplyLayerInvalidHardlink(t *testing.T) {
for i, headers := range [][]*tar.Header{
{ // try reading victim/hello (../)
{
Name: "dotdot",
Typeflag: tar.TypeLink,
Linkname: "../victim/hello",
Mode: 0644,
},
},
{ // try reading victim/hello (/../)
{
Name: "slash-dotdot",
Typeflag: tar.TypeLink,
// Note the leading slash
Linkname: "/../victim/hello",
Mode: 0644,
},
},
{ // try writing victim/file
{
Name: "loophole-victim",
Typeflag: tar.TypeLink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "loophole-victim/file",
Typeflag: tar.TypeReg,
Mode: 0644,
},
},
{ // try reading victim/hello (hardlink, symlink)
{
Name: "loophole-victim",
Typeflag: tar.TypeLink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "symlink",
Typeflag: tar.TypeSymlink,
Linkname: "loophole-victim/hello",
Mode: 0644,
},
},
{ // Try reading victim/hello (hardlink, hardlink)
{
Name: "loophole-victim",
Typeflag: tar.TypeLink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "hardlink",
Typeflag: tar.TypeLink,
Linkname: "loophole-victim/hello",
Mode: 0644,
},
},
{ // Try removing victim directory (hardlink)
{
Name: "loophole-victim",
Typeflag: tar.TypeLink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "loophole-victim",
Typeflag: tar.TypeReg,
Mode: 0644,
},
},
} {
if err := testBreakout("applylayer", "docker-TestApplyLayerInvalidHardlink", headers); err != nil {
t.Fatalf("i=%d. %v", i, err)
}
}
}
func TestApplyLayerInvalidSymlink(t *testing.T) {
for i, headers := range [][]*tar.Header{
{ // try reading victim/hello (../)
{
Name: "dotdot",
Typeflag: tar.TypeSymlink,
Linkname: "../victim/hello",
Mode: 0644,
},
},
{ // try reading victim/hello (/../)
{
Name: "slash-dotdot",
Typeflag: tar.TypeSymlink,
// Note the leading slash
Linkname: "/../victim/hello",
Mode: 0644,
},
},
{ // try writing victim/file
{
Name: "loophole-victim",
Typeflag: tar.TypeSymlink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "loophole-victim/file",
Typeflag: tar.TypeReg,
Mode: 0644,
},
},
{ // try reading victim/hello (symlink, symlink)
{
Name: "loophole-victim",
Typeflag: tar.TypeSymlink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "symlink",
Typeflag: tar.TypeSymlink,
Linkname: "loophole-victim/hello",
Mode: 0644,
},
},
{ // try reading victim/hello (symlink, hardlink)
{
Name: "loophole-victim",
Typeflag: tar.TypeSymlink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "hardlink",
Typeflag: tar.TypeLink,
Linkname: "loophole-victim/hello",
Mode: 0644,
},
},
{ // try removing victim directory (symlink)
{
Name: "loophole-victim",
Typeflag: tar.TypeSymlink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "loophole-victim",
Typeflag: tar.TypeReg,
Mode: 0644,
},
},
} {
if err := testBreakout("applylayer", "docker-TestApplyLayerInvalidSymlink", headers); err != nil {
t.Fatalf("i=%d. %v", i, err)
}
}
}
func TestApplyLayerWhiteouts(t *testing.T) {
wd, err := ioutil.TempDir("", "graphdriver-test-whiteouts")
if err != nil {
return
}
defer os.RemoveAll(wd)
base := []string{
".baz",
"bar/",
"bar/bax",
"bar/bay/",
"baz",
"foo/",
"foo/.abc",
"foo/.bcd/",
"foo/.bcd/a",
"foo/cde/",
"foo/cde/def",
"foo/cde/efg",
"foo/fgh",
"foobar",
}
type tcase struct {
change, expected []string
}
tcases := []tcase{
{
base,
base,
},
{
[]string{
".bay",
".wh.baz",
"foo/",
"foo/.bce",
"foo/.wh..wh..opq",
"foo/cde/",
"foo/cde/efg",
},
[]string{
".bay",
".baz",
"bar/",
"bar/bax",
"bar/bay/",
"foo/",
"foo/.bce",
"foo/cde/",
"foo/cde/efg",
"foobar",
},
},
{
[]string{
".bay",
".wh..baz",
".wh.foobar",
"foo/",
"foo/.abc",
"foo/.wh.cde",
"bar/",
},
[]string{
".bay",
"bar/",
"bar/bax",
"bar/bay/",
"foo/",
"foo/.abc",
"foo/.bce",
},
},
{
[]string{
".abc",
".wh..wh..opq",
"foobar",
},
[]string{
".abc",
"foobar",
},
},
}
for i, tc := range tcases {
l, err := makeTestLayer(tc.change)
if err != nil {
t.Fatal(err)
}
_, err = UnpackLayer(wd, l, nil)
if err != nil {
t.Fatal(err)
}
err = l.Close()
if err != nil {
t.Fatal(err)
}
paths, err := readDirContents(wd)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(tc.expected, paths) {
t.Fatalf("invalid files for layer %d: expected %q, got %q", i, tc.expected, paths)
}
}
}
func makeTestLayer(paths []string) (rc io.ReadCloser, err error) {
tmpDir, err := ioutil.TempDir("", "graphdriver-test-mklayer")
if err != nil {
return
}
defer func() {
if err != nil {
os.RemoveAll(tmpDir)
}
}()
for _, p := range paths {
if p[len(p)-1] == filepath.Separator {
if err = os.MkdirAll(filepath.Join(tmpDir, p), 0700); err != nil {
return
}
} else {
if err = ioutil.WriteFile(filepath.Join(tmpDir, p), nil, 0600); err != nil {
return
}
}
}
archive, err := Tar(tmpDir, Uncompressed)
if err != nil {
return
}
return ioutils.NewReadCloserWrapper(archive, func() error {
err := archive.Close()
os.RemoveAll(tmpDir)
return err
}), nil
}
func readDirContents(root string) ([]string, error) {
var files []string
err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if path == root {
return nil
}
rel, err := filepath.Rel(root, path)
if err != nil {
return err
}
if info.IsDir() {
rel = rel + "/"
}
files = append(files, rel)
return nil
})
if err != nil {
return nil, err
}
return files, nil
}

View File

@@ -0,0 +1,97 @@
// +build ignore
// Simple tool to create an archive stream from an old and new directory
//
// By default it will stream the comparison of two temporary directories with junk files
package main
import (
"flag"
"fmt"
"io"
"io/ioutil"
"os"
"path"
"github.com/Sirupsen/logrus"
"github.com/hyperhq/hypercli/pkg/archive"
)
var (
flDebug = flag.Bool("D", false, "debugging output")
flNewDir = flag.String("newdir", "", "")
flOldDir = flag.String("olddir", "", "")
log = logrus.New()
)
func main() {
flag.Usage = func() {
fmt.Println("Produce a tar from comparing two directory paths. By default a demo tar is created of around 200 files (including hardlinks)")
fmt.Printf("%s [OPTIONS]\n", os.Args[0])
flag.PrintDefaults()
}
flag.Parse()
log.Out = os.Stderr
if (len(os.Getenv("DEBUG")) > 0) || *flDebug {
logrus.SetLevel(logrus.DebugLevel)
}
var newDir, oldDir string
if len(*flNewDir) == 0 {
var err error
newDir, err = ioutil.TempDir("", "docker-test-newDir")
if err != nil {
log.Fatal(err)
}
defer os.RemoveAll(newDir)
if _, err := prepareUntarSourceDirectory(100, newDir, true); err != nil {
log.Fatal(err)
}
} else {
newDir = *flNewDir
}
if len(*flOldDir) == 0 {
oldDir, err := ioutil.TempDir("", "docker-test-oldDir")
if err != nil {
log.Fatal(err)
}
defer os.RemoveAll(oldDir)
} else {
oldDir = *flOldDir
}
changes, err := archive.ChangesDirs(newDir, oldDir)
if err != nil {
log.Fatal(err)
}
a, err := archive.ExportChanges(newDir, changes)
if err != nil {
log.Fatal(err)
}
defer a.Close()
i, err := io.Copy(os.Stdout, a)
if err != nil && err != io.EOF {
log.Fatal(err)
}
fmt.Fprintf(os.Stderr, "wrote archive of %d bytes", i)
}
func prepareUntarSourceDirectory(numberOfFiles int, targetPath string, makeLinks bool) (int, error) {
fileData := []byte("fooo")
for n := 0; n < numberOfFiles; n++ {
fileName := fmt.Sprintf("file-%d", n)
if err := ioutil.WriteFile(path.Join(targetPath, fileName), fileData, 0700); err != nil {
return 0, err
}
if makeLinks {
if err := os.Link(path.Join(targetPath, fileName), path.Join(targetPath, fileName+"-link")); err != nil {
return 0, err
}
}
}
totalSize := numberOfFiles * len(fileData)
return totalSize, nil
}

Binary file not shown.

View File

@@ -0,0 +1,16 @@
package archive
import (
"syscall"
"time"
)
func timeToTimespec(time time.Time) (ts syscall.Timespec) {
if time.IsZero() {
// Return UTIME_OMIT special value
ts.Sec = 0
ts.Nsec = ((1 << 30) - 2)
return
}
return syscall.NsecToTimespec(time.UnixNano())
}

View File

@@ -0,0 +1,16 @@
// +build !linux
package archive
import (
"syscall"
"time"
)
func timeToTimespec(time time.Time) (ts syscall.Timespec) {
nsec := int64(0)
if !time.IsZero() {
nsec = time.UnixNano()
}
return syscall.NsecToTimespec(nsec)
}

View File

@@ -0,0 +1,166 @@
package archive
import (
"archive/tar"
"bytes"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"time"
)
var testUntarFns = map[string]func(string, io.Reader) error{
"untar": func(dest string, r io.Reader) error {
return Untar(r, dest, nil)
},
"applylayer": func(dest string, r io.Reader) error {
_, err := ApplyLayer(dest, Reader(r))
return err
},
}
// testBreakout is a helper function that, within the provided `tmpdir` directory,
// creates a `victim` folder with a generated `hello` file in it.
// `untar` extracts to a directory named `dest`, the tar file created from `headers`.
//
// Here are the tested scenarios:
// - removed `victim` folder (write)
// - removed files from `victim` folder (write)
// - new files in `victim` folder (write)
// - modified files in `victim` folder (write)
// - file in `dest` with same content as `victim/hello` (read)
//
// When using testBreakout make sure you cover one of the scenarios listed above.
func testBreakout(untarFn string, tmpdir string, headers []*tar.Header) error {
tmpdir, err := ioutil.TempDir("", tmpdir)
if err != nil {
return err
}
defer os.RemoveAll(tmpdir)
dest := filepath.Join(tmpdir, "dest")
if err := os.Mkdir(dest, 0755); err != nil {
return err
}
victim := filepath.Join(tmpdir, "victim")
if err := os.Mkdir(victim, 0755); err != nil {
return err
}
hello := filepath.Join(victim, "hello")
helloData, err := time.Now().MarshalText()
if err != nil {
return err
}
if err := ioutil.WriteFile(hello, helloData, 0644); err != nil {
return err
}
helloStat, err := os.Stat(hello)
if err != nil {
return err
}
reader, writer := io.Pipe()
go func() {
t := tar.NewWriter(writer)
for _, hdr := range headers {
t.WriteHeader(hdr)
}
t.Close()
}()
untar := testUntarFns[untarFn]
if untar == nil {
return fmt.Errorf("could not find untar function %q in testUntarFns", untarFn)
}
if err := untar(dest, reader); err != nil {
if _, ok := err.(breakoutError); !ok {
// If untar returns an error unrelated to an archive breakout,
// then consider this an unexpected error and abort.
return err
}
// Here, untar detected the breakout.
// Let's move on verifying that indeed there was no breakout.
fmt.Printf("breakoutError: %v\n", err)
}
// Check victim folder
f, err := os.Open(victim)
if err != nil {
// codepath taken if victim folder was removed
return fmt.Errorf("archive breakout: error reading %q: %v", victim, err)
}
defer f.Close()
// Check contents of victim folder
//
// We are only interested in getting 2 files from the victim folder, because if all is well
// we expect only one result, the `hello` file. If there is a second result, it cannot
// hold the same name `hello` and we assume that a new file got created in the victim folder.
// That is enough to detect an archive breakout.
names, err := f.Readdirnames(2)
if err != nil {
// codepath taken if victim is not a folder
return fmt.Errorf("archive breakout: error reading directory content of %q: %v", victim, err)
}
for _, name := range names {
if name != "hello" {
// codepath taken if new file was created in victim folder
return fmt.Errorf("archive breakout: new file %q", name)
}
}
// Check victim/hello
f, err = os.Open(hello)
if err != nil {
// codepath taken if read permissions were removed
return fmt.Errorf("archive breakout: could not lstat %q: %v", hello, err)
}
defer f.Close()
b, err := ioutil.ReadAll(f)
if err != nil {
return err
}
fi, err := f.Stat()
if err != nil {
return err
}
if helloStat.IsDir() != fi.IsDir() ||
// TODO: cannot check for fi.ModTime() change
helloStat.Mode() != fi.Mode() ||
helloStat.Size() != fi.Size() ||
!bytes.Equal(helloData, b) {
// codepath taken if hello has been modified
return fmt.Errorf("archive breakout: file %q has been modified. Contents: expected=%q, got=%q. FileInfo: expected=%#v, got=%#v", hello, helloData, b, helloStat, fi)
}
// Check that nothing in dest/ has the same content as victim/hello.
// Since victim/hello was generated with time.Now(), it is safe to assume
// that any file whose content matches exactly victim/hello, managed somehow
// to access victim/hello.
return filepath.Walk(dest, func(path string, info os.FileInfo, err error) error {
if info.IsDir() {
if err != nil {
// skip directory if error
return filepath.SkipDir
}
// enter directory
return nil
}
if err != nil {
// skip file if error
return nil
}
b, err := ioutil.ReadFile(path)
if err != nil {
// Houston, we have a problem. Aborting (space)walk.
return err
}
if bytes.Equal(helloData, b) {
return fmt.Errorf("archive breakout: file %q has been accessed via %q", hello, path)
}
return nil
})
}

View File

@@ -0,0 +1,23 @@
package archive
// Whiteouts are files with a special meaning for the layered filesystem.
// Docker uses AUFS whiteout files inside exported archives. In other
// filesystems these files are generated/handled on tar creation/extraction.
// WhiteoutPrefix prefix means file is a whiteout. If this is followed by a
// filename this means that file has been removed from the base layer.
const WhiteoutPrefix = ".wh."
// WhiteoutMetaPrefix prefix means whiteout has a special meaning and is not
// for removing an actual file. Normally these files are excluded from exported
// archives.
const WhiteoutMetaPrefix = WhiteoutPrefix + WhiteoutPrefix
// WhiteoutLinkDir is a directory AUFS uses for storing hardlink links to other
// layers. Normally these should not go into exported archives and all changed
// hardlinks should be copied to the top layer.
const WhiteoutLinkDir = WhiteoutMetaPrefix + "plnk"
// WhiteoutOpaqueDir file means directory has been made opaque - meaning
// readdir calls to this directory do not follow to lower layers.
const WhiteoutOpaqueDir = WhiteoutMetaPrefix + ".opq"

59
vendor/github.com/hyperhq/hypercli/pkg/archive/wrap.go generated vendored Normal file
View File

@@ -0,0 +1,59 @@
package archive
import (
"archive/tar"
"bytes"
"io/ioutil"
)
// Generate generates a new archive from the content provided
// as input.
//
// `files` is a sequence of path/content pairs. A new file is
// added to the archive for each pair.
// If the last pair is incomplete, the file is created with an
// empty content. For example:
//
// Generate("foo.txt", "hello world", "emptyfile")
//
// The above call will return an archive with 2 files:
// * ./foo.txt with content "hello world"
// * ./empty with empty content
//
// FIXME: stream content instead of buffering
// FIXME: specify permissions and other archive metadata
func Generate(input ...string) (Archive, error) {
files := parseStringPairs(input...)
buf := new(bytes.Buffer)
tw := tar.NewWriter(buf)
for _, file := range files {
name, content := file[0], file[1]
hdr := &tar.Header{
Name: name,
Size: int64(len(content)),
}
if err := tw.WriteHeader(hdr); err != nil {
return nil, err
}
if _, err := tw.Write([]byte(content)); err != nil {
return nil, err
}
}
if err := tw.Close(); err != nil {
return nil, err
}
return ioutil.NopCloser(buf), nil
}
func parseStringPairs(input ...string) (output [][2]string) {
output = make([][2]string, 0, len(input)/2+1)
for i := 0; i < len(input); i += 2 {
var pair [2]string
pair[0] = input[i]
if i+1 < len(input) {
pair[1] = input[i+1]
}
output = append(output, pair)
}
return
}

View File

@@ -0,0 +1,98 @@
package archive
import (
"archive/tar"
"bytes"
"io"
"testing"
)
func TestGenerateEmptyFile(t *testing.T) {
archive, err := Generate("emptyFile")
if err != nil {
t.Fatal(err)
}
if archive == nil {
t.Fatal("The generated archive should not be nil.")
}
expectedFiles := [][]string{
{"emptyFile", ""},
}
tr := tar.NewReader(archive)
actualFiles := make([][]string, 0, 10)
i := 0
for {
hdr, err := tr.Next()
if err == io.EOF {
break
}
if err != nil {
t.Fatal(err)
}
buf := new(bytes.Buffer)
buf.ReadFrom(tr)
content := buf.String()
actualFiles = append(actualFiles, []string{hdr.Name, content})
i++
}
if len(actualFiles) != len(expectedFiles) {
t.Fatalf("Number of expected file %d, got %d.", len(expectedFiles), len(actualFiles))
}
for i := 0; i < len(expectedFiles); i++ {
actual := actualFiles[i]
expected := expectedFiles[i]
if actual[0] != expected[0] {
t.Fatalf("Expected name '%s', Actual name '%s'", expected[0], actual[0])
}
if actual[1] != expected[1] {
t.Fatalf("Expected content '%s', Actual content '%s'", expected[1], actual[1])
}
}
}
func TestGenerateWithContent(t *testing.T) {
archive, err := Generate("file", "content")
if err != nil {
t.Fatal(err)
}
if archive == nil {
t.Fatal("The generated archive should not be nil.")
}
expectedFiles := [][]string{
{"file", "content"},
}
tr := tar.NewReader(archive)
actualFiles := make([][]string, 0, 10)
i := 0
for {
hdr, err := tr.Next()
if err == io.EOF {
break
}
if err != nil {
t.Fatal(err)
}
buf := new(bytes.Buffer)
buf.ReadFrom(tr)
content := buf.String()
actualFiles = append(actualFiles, []string{hdr.Name, content})
i++
}
if len(actualFiles) != len(expectedFiles) {
t.Fatalf("Number of expected file %d, got %d.", len(expectedFiles), len(actualFiles))
}
for i := 0; i < len(expectedFiles); i++ {
actual := actualFiles[i]
expected := expectedFiles[i]
if actual[0] != expected[0] {
t.Fatalf("Expected name '%s', Actual name '%s'", expected[0], actual[0])
}
if actual[1] != expected[1] {
t.Fatalf("Expected content '%s', Actual content '%s'", expected[1], actual[1])
}
}
}

View File

@@ -0,0 +1,54 @@
package authorization
const (
// AuthZApiRequest is the url for daemon request authorization
AuthZApiRequest = "AuthZPlugin.AuthZReq"
// AuthZApiResponse is the url for daemon response authorization
AuthZApiResponse = "AuthZPlugin.AuthZRes"
// AuthZApiImplements is the name of the interface all AuthZ plugins implement
AuthZApiImplements = "authz"
)
// Request holds data required for authZ plugins
type Request struct {
// User holds the user extracted by AuthN mechanism
User string `json:"User,omitempty"`
// UserAuthNMethod holds the mechanism used to extract user details (e.g., krb)
UserAuthNMethod string `json:"UserAuthNMethod,omitempty"`
// RequestMethod holds the HTTP method (GET/POST/PUT)
RequestMethod string `json:"RequestMethod,omitempty"`
// RequestUri holds the full HTTP uri (e.g., /v1.21/version)
RequestURI string `json:"RequestUri,omitempty"`
// RequestBody stores the raw request body sent to the docker daemon
RequestBody []byte `json:"RequestBody,omitempty"`
// RequestHeaders stores the raw request headers sent to the docker daemon
RequestHeaders map[string]string `json:"RequestHeaders,omitempty"`
// ResponseStatusCode stores the status code returned from docker daemon
ResponseStatusCode int `json:"ResponseStatusCode,omitempty"`
// ResponseBody stores the raw response body sent from docker daemon
ResponseBody []byte `json:"ResponseBody,omitempty"`
// ResponseHeaders stores the response headers sent to the docker daemon
ResponseHeaders map[string]string `json:"ResponseHeaders,omitempty"`
}
// Response represents authZ plugin response
type Response struct {
// Allow indicating whether the user is allowed or not
Allow bool `json:"Allow"`
// Msg stores the authorization message
Msg string `json:"Msg,omitempty"`
// Err stores a message in case there's an error
Err string `json:"Err,omitempty"`
}

View File

@@ -0,0 +1,168 @@
package authorization
import (
"bufio"
"bytes"
"fmt"
"io"
"net/http"
"strings"
"github.com/Sirupsen/logrus"
"github.com/hyperhq/hypercli/pkg/ioutils"
)
const maxBodySize = 1048576 // 1MB
// NewCtx creates new authZ context, it is used to store authorization information related to a specific docker
// REST http session
// A context provides two method:
// Authenticate Request:
// Call authZ plugins with current REST request and AuthN response
// Request contains full HTTP packet sent to the docker daemon
// https://docs.docker.com/reference/api/docker_remote_api/
//
// Authenticate Response:
// Call authZ plugins with full info about current REST request, REST response and AuthN response
// The response from this method may contains content that overrides the daemon response
// This allows authZ plugins to filter privileged content
//
// If multiple authZ plugins are specified, the block/allow decision is based on ANDing all plugin results
// For response manipulation, the response from each plugin is piped between plugins. Plugin execution order
// is determined according to daemon parameters
func NewCtx(authZPlugins []Plugin, user, userAuthNMethod, requestMethod, requestURI string) *Ctx {
return &Ctx{
plugins: authZPlugins,
user: user,
userAuthNMethod: userAuthNMethod,
requestMethod: requestMethod,
requestURI: requestURI,
}
}
// Ctx stores a a single request-response interaction context
type Ctx struct {
user string
userAuthNMethod string
requestMethod string
requestURI string
plugins []Plugin
// authReq stores the cached request object for the current transaction
authReq *Request
}
// AuthZRequest authorized the request to the docker daemon using authZ plugins
func (ctx *Ctx) AuthZRequest(w http.ResponseWriter, r *http.Request) error {
var body []byte
if sendBody(ctx.requestURI, r.Header) {
if r.ContentLength < maxBodySize {
var err error
body, r.Body, err = drainBody(r.Body)
if err != nil {
return err
}
}
}
var h bytes.Buffer
if err := r.Header.Write(&h); err != nil {
return err
}
ctx.authReq = &Request{
User: ctx.user,
UserAuthNMethod: ctx.userAuthNMethod,
RequestMethod: ctx.requestMethod,
RequestURI: ctx.requestURI,
RequestBody: body,
RequestHeaders: headers(r.Header),
}
for _, plugin := range ctx.plugins {
logrus.Debugf("AuthZ request using plugin %s", plugin.Name())
authRes, err := plugin.AuthZRequest(ctx.authReq)
if err != nil {
return fmt.Errorf("plugin %s failed with error: %s", plugin.Name(), err)
}
if !authRes.Allow {
return fmt.Errorf("authorization denied by plugin %s: %s", plugin.Name(), authRes.Msg)
}
}
return nil
}
// AuthZResponse authorized and manipulates the response from docker daemon using authZ plugins
func (ctx *Ctx) AuthZResponse(rm ResponseModifier, r *http.Request) error {
ctx.authReq.ResponseStatusCode = rm.StatusCode()
ctx.authReq.ResponseHeaders = headers(rm.Header())
if sendBody(ctx.requestURI, rm.Header()) {
ctx.authReq.ResponseBody = rm.RawBody()
}
for _, plugin := range ctx.plugins {
logrus.Debugf("AuthZ response using plugin %s", plugin.Name())
authRes, err := plugin.AuthZResponse(ctx.authReq)
if err != nil {
return fmt.Errorf("plugin %s failed with error: %s", plugin.Name(), err)
}
if !authRes.Allow {
return fmt.Errorf("authorization denied by plugin %s: %s", plugin.Name(), authRes.Msg)
}
}
rm.Flush()
return nil
}
// drainBody dump the body, it reads the body data into memory and
// see go sources /go/src/net/http/httputil/dump.go
func drainBody(body io.ReadCloser) ([]byte, io.ReadCloser, error) {
bufReader := bufio.NewReaderSize(body, maxBodySize)
newBody := ioutils.NewReadCloserWrapper(bufReader, func() error { return body.Close() })
data, err := bufReader.Peek(maxBodySize)
if err != io.EOF {
// This means the request is larger than our max
if err == bufio.ErrBufferFull {
return nil, newBody, nil
}
// This means we had an error reading
return nil, nil, err
}
return data, newBody, nil
}
// sendBody returns true when request/response body should be sent to AuthZPlugin
func sendBody(url string, header http.Header) bool {
// Skip body for auth endpoint
if strings.HasSuffix(url, "/auth") {
return false
}
// body is sent only for text or json messages
v := header.Get("Content-Type")
return strings.HasPrefix(v, "text/") || v == "application/json"
}
// headers returns flatten version of the http headers excluding authorization
func headers(header http.Header) map[string]string {
v := make(map[string]string, 0)
for k, values := range header {
// Skip authorization headers
if strings.EqualFold(k, "Authorization") || strings.EqualFold(k, "X-Registry-Config") || strings.EqualFold(k, "X-Registry-Auth") {
continue
}
for _, val := range values {
v[k] = val
}
}
return v
}

View File

@@ -0,0 +1,233 @@
package authorization
import (
"encoding/json"
"io/ioutil"
"log"
"net"
"net/http"
"net/http/httptest"
"os"
"path"
"reflect"
"testing"
"github.com/hyperhq/hypercli/pkg/plugins"
"github.com/docker/go-connections/tlsconfig"
"github.com/gorilla/mux"
)
const pluginAddress = "authzplugin.sock"
func TestAuthZRequestPluginError(t *testing.T) {
server := authZPluginTestServer{t: t}
go server.start()
defer server.stop()
authZPlugin := createTestPlugin(t)
request := Request{
User: "user",
RequestBody: []byte("sample body"),
RequestURI: "www.authz.com",
RequestMethod: "GET",
RequestHeaders: map[string]string{"header": "value"},
}
server.replayResponse = Response{
Err: "an error",
}
actualResponse, err := authZPlugin.AuthZRequest(&request)
if err != nil {
t.Fatalf("Failed to authorize request %v", err)
}
if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
t.Fatalf("Response must be equal")
}
if !reflect.DeepEqual(request, server.recordedRequest) {
t.Fatalf("Requests must be equal")
}
}
func TestAuthZRequestPlugin(t *testing.T) {
server := authZPluginTestServer{t: t}
go server.start()
defer server.stop()
authZPlugin := createTestPlugin(t)
request := Request{
User: "user",
RequestBody: []byte("sample body"),
RequestURI: "www.authz.com",
RequestMethod: "GET",
RequestHeaders: map[string]string{"header": "value"},
}
server.replayResponse = Response{
Allow: true,
Msg: "Sample message",
}
actualResponse, err := authZPlugin.AuthZRequest(&request)
if err != nil {
t.Fatalf("Failed to authorize request %v", err)
}
if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
t.Fatalf("Response must be equal")
}
if !reflect.DeepEqual(request, server.recordedRequest) {
t.Fatalf("Requests must be equal")
}
}
func TestAuthZResponsePlugin(t *testing.T) {
server := authZPluginTestServer{t: t}
go server.start()
defer server.stop()
authZPlugin := createTestPlugin(t)
request := Request{
User: "user",
RequestBody: []byte("sample body"),
}
server.replayResponse = Response{
Allow: true,
Msg: "Sample message",
}
actualResponse, err := authZPlugin.AuthZResponse(&request)
if err != nil {
t.Fatalf("Failed to authorize request %v", err)
}
if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
t.Fatalf("Response must be equal")
}
if !reflect.DeepEqual(request, server.recordedRequest) {
t.Fatalf("Requests must be equal")
}
}
func TestResponseModifier(t *testing.T) {
r := httptest.NewRecorder()
m := NewResponseModifier(r)
m.Header().Set("h1", "v1")
m.Write([]byte("body"))
m.WriteHeader(500)
m.Flush()
if r.Header().Get("h1") != "v1" {
t.Fatalf("Header value must exists %s", r.Header().Get("h1"))
}
if !reflect.DeepEqual(r.Body.Bytes(), []byte("body")) {
t.Fatalf("Body value must exists %s", r.Body.Bytes())
}
if r.Code != 500 {
t.Fatalf("Status code must be correct %d", r.Code)
}
}
func TestResponseModifierOverride(t *testing.T) {
r := httptest.NewRecorder()
m := NewResponseModifier(r)
m.Header().Set("h1", "v1")
m.Write([]byte("body"))
m.WriteHeader(500)
overrideHeader := make(http.Header)
overrideHeader.Add("h1", "v2")
overrideHeaderBytes, err := json.Marshal(overrideHeader)
if err != nil {
t.Fatalf("override header failed %v", err)
}
m.OverrideHeader(overrideHeaderBytes)
m.OverrideBody([]byte("override body"))
m.OverrideStatusCode(404)
m.Flush()
if r.Header().Get("h1") != "v2" {
t.Fatalf("Header value must exists %s", r.Header().Get("h1"))
}
if !reflect.DeepEqual(r.Body.Bytes(), []byte("override body")) {
t.Fatalf("Body value must exists %s", r.Body.Bytes())
}
if r.Code != 404 {
t.Fatalf("Status code must be correct %d", r.Code)
}
}
// createTestPlugin creates a new sample authorization plugin
func createTestPlugin(t *testing.T) *authorizationPlugin {
plugin := &plugins.Plugin{Name: "authz"}
pwd, err := os.Getwd()
if err != nil {
log.Fatal(err)
}
plugin.Client, err = plugins.NewClient("unix:///"+path.Join(pwd, pluginAddress), tlsconfig.Options{InsecureSkipVerify: true})
if err != nil {
t.Fatalf("Failed to create client %v", err)
}
return &authorizationPlugin{name: "plugin", plugin: plugin}
}
// AuthZPluginTestServer is a simple server that implements the authZ plugin interface
type authZPluginTestServer struct {
listener net.Listener
t *testing.T
// request stores the request sent from the daemon to the plugin
recordedRequest Request
// response stores the response sent from the plugin to the daemon
replayResponse Response
}
// start starts the test server that implements the plugin
func (t *authZPluginTestServer) start() {
r := mux.NewRouter()
os.Remove(pluginAddress)
l, err := net.ListenUnix("unix", &net.UnixAddr{Name: pluginAddress, Net: "unix"})
if err != nil {
t.t.Fatalf("Failed to listen %v", err)
}
t.listener = l
r.HandleFunc("/Plugin.Activate", t.activate)
r.HandleFunc("/"+AuthZApiRequest, t.auth)
r.HandleFunc("/"+AuthZApiResponse, t.auth)
t.listener, err = net.Listen("tcp", pluginAddress)
server := http.Server{Handler: r, Addr: pluginAddress}
server.Serve(l)
}
// stop stops the test server that implements the plugin
func (t *authZPluginTestServer) stop() {
os.Remove(pluginAddress)
if t.listener != nil {
t.listener.Close()
}
}
// auth is a used to record/replay the authentication api messages
func (t *authZPluginTestServer) auth(w http.ResponseWriter, r *http.Request) {
t.recordedRequest = Request{}
defer r.Body.Close()
body, err := ioutil.ReadAll(r.Body)
json.Unmarshal(body, &t.recordedRequest)
b, err := json.Marshal(t.replayResponse)
if err != nil {
log.Fatal(err)
}
w.Write(b)
}
func (t *authZPluginTestServer) activate(w http.ResponseWriter, r *http.Request) {
b, err := json.Marshal(plugins.Manifest{Implements: []string{AuthZApiImplements}})
if err != nil {
log.Fatal(err)
}
w.Write(b)
}

View File

@@ -0,0 +1,83 @@
package authorization
import "github.com/hyperhq/hypercli/pkg/plugins"
// Plugin allows third party plugins to authorize requests and responses
// in the context of docker API
type Plugin interface {
// Name returns the registered plugin name
Name() string
// AuthZRequest authorize the request from the client to the daemon
AuthZRequest(*Request) (*Response, error)
// AuthZResponse authorize the response from the daemon to the client
AuthZResponse(*Request) (*Response, error)
}
// NewPlugins constructs and initialize the authorization plugins based on plugin names
func NewPlugins(names []string) []Plugin {
plugins := []Plugin{}
pluginsMap := make(map[string]struct{})
for _, name := range names {
if _, ok := pluginsMap[name]; ok {
continue
}
pluginsMap[name] = struct{}{}
plugins = append(plugins, newAuthorizationPlugin(name))
}
return plugins
}
// authorizationPlugin is an internal adapter to docker plugin system
type authorizationPlugin struct {
plugin *plugins.Plugin
name string
}
func newAuthorizationPlugin(name string) Plugin {
return &authorizationPlugin{name: name}
}
func (a *authorizationPlugin) Name() string {
return a.name
}
func (a *authorizationPlugin) AuthZRequest(authReq *Request) (*Response, error) {
if err := a.initPlugin(); err != nil {
return nil, err
}
authRes := &Response{}
if err := a.plugin.Client.Call(AuthZApiRequest, authReq, authRes); err != nil {
return nil, err
}
return authRes, nil
}
func (a *authorizationPlugin) AuthZResponse(authReq *Request) (*Response, error) {
if err := a.initPlugin(); err != nil {
return nil, err
}
authRes := &Response{}
if err := a.plugin.Client.Call(AuthZApiResponse, authReq, authRes); err != nil {
return nil, err
}
return authRes, nil
}
// initPlugin initialize the authorization plugin if needed
func (a *authorizationPlugin) initPlugin() error {
// Lazy loading of plugins
if a.plugin == nil {
var err error
a.plugin, err = plugins.Get(a.name, AuthZApiImplements)
if err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,136 @@
package authorization
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"net"
"net/http"
)
// ResponseModifier allows authorization plugins to read and modify the content of the http.response
type ResponseModifier interface {
http.ResponseWriter
// RawBody returns the current http content
RawBody() []byte
// RawHeaders returns the current content of the http headers
RawHeaders() ([]byte, error)
// StatusCode returns the current status code
StatusCode() int
// OverrideBody replace the body of the HTTP reply
OverrideBody(b []byte)
// OverrideHeader replace the headers of the HTTP reply
OverrideHeader(b []byte) error
// OverrideStatusCode replaces the status code of the HTTP reply
OverrideStatusCode(statusCode int)
// Flush flushes all data to the HTTP response
Flush() error
}
// NewResponseModifier creates a wrapper to an http.ResponseWriter to allow inspecting and modifying the content
func NewResponseModifier(rw http.ResponseWriter) ResponseModifier {
return &responseModifier{rw: rw, header: make(http.Header)}
}
// responseModifier is used as an adapter to http.ResponseWriter in order to manipulate and explore
// the http request/response from docker daemon
type responseModifier struct {
// The original response writer
rw http.ResponseWriter
status int
// body holds the response body
body []byte
// header holds the response header
header http.Header
// statusCode holds the response status code
statusCode int
}
// WriterHeader stores the http status code
func (rm *responseModifier) WriteHeader(s int) {
rm.statusCode = s
}
// Header returns the internal http header
func (rm *responseModifier) Header() http.Header {
return rm.header
}
// Header returns the internal http header
func (rm *responseModifier) StatusCode() int {
return rm.statusCode
}
// Override replace the body of the HTTP reply
func (rm *responseModifier) OverrideBody(b []byte) {
rm.body = b
}
func (rm *responseModifier) OverrideStatusCode(statusCode int) {
rm.statusCode = statusCode
}
// Override replace the headers of the HTTP reply
func (rm *responseModifier) OverrideHeader(b []byte) error {
header := http.Header{}
if err := json.Unmarshal(b, &header); err != nil {
return err
}
rm.header = header
return nil
}
// Write stores the byte array inside content
func (rm *responseModifier) Write(b []byte) (int, error) {
rm.body = append(rm.body, b...)
return len(b), nil
}
// Body returns the response body
func (rm *responseModifier) RawBody() []byte {
return rm.body
}
func (rm *responseModifier) RawHeaders() ([]byte, error) {
var b bytes.Buffer
if err := rm.header.Write(&b); err != nil {
return nil, err
}
return b.Bytes(), nil
}
// Hijack returns the internal connection of the wrapped http.ResponseWriter
func (rm *responseModifier) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := rm.rw.(http.Hijacker)
if !ok {
return nil, nil, fmt.Errorf("Internal reponse writer doesn't support the Hijacker interface")
}
return hijacker.Hijack()
}
// Flush flushes all data to the HTTP response
func (rm *responseModifier) Flush() error {
// Copy the status code
if rm.statusCode > 0 {
rm.rw.WriteHeader(rm.statusCode)
}
// Copy the header
for k, vv := range rm.header {
for _, v := range vv {
rm.rw.Header().Add(k, v)
}
}
// Write body
_, err := rm.rw.Write(rm.body)
return err
}

View File

@@ -0,0 +1,49 @@
package broadcaster
import (
"io"
"sync"
)
// Unbuffered accumulates multiple io.WriteCloser by stream.
type Unbuffered struct {
mu sync.Mutex
writers []io.WriteCloser
}
// Add adds new io.WriteCloser.
func (w *Unbuffered) Add(writer io.WriteCloser) {
w.mu.Lock()
w.writers = append(w.writers, writer)
w.mu.Unlock()
}
// Write writes bytes to all writers. Failed writers will be evicted during
// this call.
func (w *Unbuffered) Write(p []byte) (n int, err error) {
w.mu.Lock()
var evict []int
for i, sw := range w.writers {
if n, err := sw.Write(p); err != nil || n != len(p) {
// On error, evict the writer
evict = append(evict, i)
}
}
for n, i := range evict {
w.writers = append(w.writers[:i-n], w.writers[i-n+1:]...)
}
w.mu.Unlock()
return len(p), nil
}
// Clean closes and removes all writers. Last non-eol-terminated part of data
// will be saved.
func (w *Unbuffered) Clean() error {
w.mu.Lock()
for _, sw := range w.writers {
sw.Close()
}
w.writers = nil
w.mu.Unlock()
return nil
}

View File

@@ -0,0 +1,162 @@
package broadcaster
import (
"bytes"
"errors"
"strings"
"testing"
)
type dummyWriter struct {
buffer bytes.Buffer
failOnWrite bool
}
func (dw *dummyWriter) Write(p []byte) (n int, err error) {
if dw.failOnWrite {
return 0, errors.New("Fake fail")
}
return dw.buffer.Write(p)
}
func (dw *dummyWriter) String() string {
return dw.buffer.String()
}
func (dw *dummyWriter) Close() error {
return nil
}
func TestUnbuffered(t *testing.T) {
writer := new(Unbuffered)
// Test 1: Both bufferA and bufferB should contain "foo"
bufferA := &dummyWriter{}
writer.Add(bufferA)
bufferB := &dummyWriter{}
writer.Add(bufferB)
writer.Write([]byte("foo"))
if bufferA.String() != "foo" {
t.Errorf("Buffer contains %v", bufferA.String())
}
if bufferB.String() != "foo" {
t.Errorf("Buffer contains %v", bufferB.String())
}
// Test2: bufferA and bufferB should contain "foobar",
// while bufferC should only contain "bar"
bufferC := &dummyWriter{}
writer.Add(bufferC)
writer.Write([]byte("bar"))
if bufferA.String() != "foobar" {
t.Errorf("Buffer contains %v", bufferA.String())
}
if bufferB.String() != "foobar" {
t.Errorf("Buffer contains %v", bufferB.String())
}
if bufferC.String() != "bar" {
t.Errorf("Buffer contains %v", bufferC.String())
}
// Test3: Test eviction on failure
bufferA.failOnWrite = true
writer.Write([]byte("fail"))
if bufferA.String() != "foobar" {
t.Errorf("Buffer contains %v", bufferA.String())
}
if bufferC.String() != "barfail" {
t.Errorf("Buffer contains %v", bufferC.String())
}
// Even though we reset the flag, no more writes should go in there
bufferA.failOnWrite = false
writer.Write([]byte("test"))
if bufferA.String() != "foobar" {
t.Errorf("Buffer contains %v", bufferA.String())
}
if bufferC.String() != "barfailtest" {
t.Errorf("Buffer contains %v", bufferC.String())
}
// Test4: Test eviction on multiple simultaneous failures
bufferB.failOnWrite = true
bufferC.failOnWrite = true
bufferD := &dummyWriter{}
writer.Add(bufferD)
writer.Write([]byte("yo"))
writer.Write([]byte("ink"))
if strings.Contains(bufferB.String(), "yoink") {
t.Errorf("bufferB received write. contents: %q", bufferB)
}
if strings.Contains(bufferC.String(), "yoink") {
t.Errorf("bufferC received write. contents: %q", bufferC)
}
if g, w := bufferD.String(), "yoink"; g != w {
t.Errorf("bufferD = %q, want %q", g, w)
}
writer.Clean()
}
type devNullCloser int
func (d devNullCloser) Close() error {
return nil
}
func (d devNullCloser) Write(buf []byte) (int, error) {
return len(buf), nil
}
// This test checks for races. It is only useful when run with the race detector.
func TestRaceUnbuffered(t *testing.T) {
writer := new(Unbuffered)
c := make(chan bool)
go func() {
writer.Add(devNullCloser(0))
c <- true
}()
writer.Write([]byte("hello"))
<-c
}
func BenchmarkUnbuffered(b *testing.B) {
writer := new(Unbuffered)
setUpWriter := func() {
for i := 0; i < 100; i++ {
writer.Add(devNullCloser(0))
writer.Add(devNullCloser(0))
writer.Add(devNullCloser(0))
}
}
testLine := "Line that thinks that it is log line from docker"
var buf bytes.Buffer
for i := 0; i < 100; i++ {
buf.Write([]byte(testLine + "\n"))
}
// line without eol
buf.Write([]byte(testLine))
testText := buf.Bytes()
b.SetBytes(int64(5 * len(testText)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
b.StopTimer()
setUpWriter()
b.StartTimer()
for j := 0; j < 5; j++ {
if _, err := writer.Write(testText); err != nil {
b.Fatal(err)
}
}
b.StopTimer()
writer.Clean()
b.StartTimer()
}
}

View File

@@ -0,0 +1,97 @@
package chrootarchive
import (
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"github.com/hyperhq/hypercli/pkg/archive"
"github.com/hyperhq/hypercli/pkg/idtools"
)
var chrootArchiver = &archive.Archiver{Untar: Untar}
// Untar reads a stream of bytes from `archive`, parses it as a tar archive,
// and unpacks it into the directory at `dest`.
// The archive may be compressed with one of the following algorithms:
// identity (uncompressed), gzip, bzip2, xz.
func Untar(tarArchive io.Reader, dest string, options *archive.TarOptions) error {
return untarHandler(tarArchive, dest, options, true)
}
// UntarUncompressed reads a stream of bytes from `archive`, parses it as a tar archive,
// and unpacks it into the directory at `dest`.
// The archive must be an uncompressed stream.
func UntarUncompressed(tarArchive io.Reader, dest string, options *archive.TarOptions) error {
return untarHandler(tarArchive, dest, options, false)
}
// Handler for teasing out the automatic decompression
func untarHandler(tarArchive io.Reader, dest string, options *archive.TarOptions, decompress bool) error {
if tarArchive == nil {
return fmt.Errorf("Empty archive")
}
if options == nil {
options = &archive.TarOptions{}
}
if options.ExcludePatterns == nil {
options.ExcludePatterns = []string{}
}
rootUID, rootGID, err := idtools.GetRootUIDGID(options.UIDMaps, options.GIDMaps)
if err != nil {
return err
}
dest = filepath.Clean(dest)
if _, err := os.Stat(dest); os.IsNotExist(err) {
if err := idtools.MkdirAllNewAs(dest, 0755, rootUID, rootGID); err != nil {
return err
}
}
r := ioutil.NopCloser(tarArchive)
if decompress {
decompressedArchive, err := archive.DecompressStream(tarArchive)
if err != nil {
return err
}
defer decompressedArchive.Close()
r = decompressedArchive
}
return invokeUnpack(r, dest, options)
}
// TarUntar is a convenience function which calls Tar and Untar, with the output of one piped into the other.
// If either Tar or Untar fails, TarUntar aborts and returns the error.
func TarUntar(src, dst string) error {
return chrootArchiver.TarUntar(src, dst)
}
// CopyWithTar creates a tar archive of filesystem path `src`, and
// unpacks it at filesystem path `dst`.
// The archive is streamed directly with fixed buffering and no
// intermediary disk IO.
func CopyWithTar(src, dst string) error {
return chrootArchiver.CopyWithTar(src, dst)
}
// CopyFileWithTar emulates the behavior of the 'cp' command-line
// for a single file. It copies a regular file from path `src` to
// path `dst`, and preserves all its metadata.
//
// If `dst` ends with a trailing slash '/' ('\' on Windows), the final
// destination path will be `dst/base(src)` or `dst\base(src)`
func CopyFileWithTar(src, dst string) (err error) {
return chrootArchiver.CopyFileWithTar(src, dst)
}
// UntarPath is a convenience function which looks for an archive
// at filesystem path `src`, and unpacks it at `dst`.
func UntarPath(src, dst string) error {
return chrootArchiver.UntarPath(src, dst)
}

View File

@@ -0,0 +1,381 @@
package chrootarchive
import (
"bytes"
"fmt"
"hash/crc32"
"io"
"io/ioutil"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/hyperhq/hypercli/pkg/archive"
"github.com/hyperhq/hypercli/pkg/reexec"
"github.com/hyperhq/hypercli/pkg/system"
)
func init() {
reexec.Init()
}
func TestChrootTarUntar(t *testing.T) {
tmpdir, err := ioutil.TempDir("", "docker-TestChrootTarUntar")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpdir)
src := filepath.Join(tmpdir, "src")
if err := system.MkdirAll(src, 0700); err != nil {
t.Fatal(err)
}
if err := ioutil.WriteFile(filepath.Join(src, "toto"), []byte("hello toto"), 0644); err != nil {
t.Fatal(err)
}
if err := ioutil.WriteFile(filepath.Join(src, "lolo"), []byte("hello lolo"), 0644); err != nil {
t.Fatal(err)
}
stream, err := archive.Tar(src, archive.Uncompressed)
if err != nil {
t.Fatal(err)
}
dest := filepath.Join(tmpdir, "src")
if err := system.MkdirAll(dest, 0700); err != nil {
t.Fatal(err)
}
if err := Untar(stream, dest, &archive.TarOptions{ExcludePatterns: []string{"lolo"}}); err != nil {
t.Fatal(err)
}
}
// gh#10426: Verify the fix for having a huge excludes list (like on `docker load` with large # of
// local images)
func TestChrootUntarWithHugeExcludesList(t *testing.T) {
tmpdir, err := ioutil.TempDir("", "docker-TestChrootUntarHugeExcludes")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpdir)
src := filepath.Join(tmpdir, "src")
if err := system.MkdirAll(src, 0700); err != nil {
t.Fatal(err)
}
if err := ioutil.WriteFile(filepath.Join(src, "toto"), []byte("hello toto"), 0644); err != nil {
t.Fatal(err)
}
stream, err := archive.Tar(src, archive.Uncompressed)
if err != nil {
t.Fatal(err)
}
dest := filepath.Join(tmpdir, "dest")
if err := system.MkdirAll(dest, 0700); err != nil {
t.Fatal(err)
}
options := &archive.TarOptions{}
//65534 entries of 64-byte strings ~= 4MB of environment space which should overflow
//on most systems when passed via environment or command line arguments
excludes := make([]string, 65534, 65534)
for i := 0; i < 65534; i++ {
excludes[i] = strings.Repeat(string(i), 64)
}
options.ExcludePatterns = excludes
if err := Untar(stream, dest, options); err != nil {
t.Fatal(err)
}
}
func TestChrootUntarEmptyArchive(t *testing.T) {
tmpdir, err := ioutil.TempDir("", "docker-TestChrootUntarEmptyArchive")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpdir)
if err := Untar(nil, tmpdir, nil); err == nil {
t.Fatal("expected error on empty archive")
}
}
func prepareSourceDirectory(numberOfFiles int, targetPath string, makeSymLinks bool) (int, error) {
fileData := []byte("fooo")
for n := 0; n < numberOfFiles; n++ {
fileName := fmt.Sprintf("file-%d", n)
if err := ioutil.WriteFile(filepath.Join(targetPath, fileName), fileData, 0700); err != nil {
return 0, err
}
if makeSymLinks {
if err := os.Symlink(filepath.Join(targetPath, fileName), filepath.Join(targetPath, fileName+"-link")); err != nil {
return 0, err
}
}
}
totalSize := numberOfFiles * len(fileData)
return totalSize, nil
}
func getHash(filename string) (uint32, error) {
stream, err := ioutil.ReadFile(filename)
if err != nil {
return 0, err
}
hash := crc32.NewIEEE()
hash.Write(stream)
return hash.Sum32(), nil
}
func compareDirectories(src string, dest string) error {
changes, err := archive.ChangesDirs(dest, src)
if err != nil {
return err
}
if len(changes) > 0 {
return fmt.Errorf("Unexpected differences after untar: %v", changes)
}
return nil
}
func compareFiles(src string, dest string) error {
srcHash, err := getHash(src)
if err != nil {
return err
}
destHash, err := getHash(dest)
if err != nil {
return err
}
if srcHash != destHash {
return fmt.Errorf("%s is different from %s", src, dest)
}
return nil
}
func TestChrootTarUntarWithSymlink(t *testing.T) {
tmpdir, err := ioutil.TempDir("", "docker-TestChrootTarUntarWithSymlink")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpdir)
src := filepath.Join(tmpdir, "src")
if err := system.MkdirAll(src, 0700); err != nil {
t.Fatal(err)
}
if _, err := prepareSourceDirectory(10, src, true); err != nil {
t.Fatal(err)
}
dest := filepath.Join(tmpdir, "dest")
if err := TarUntar(src, dest); err != nil {
t.Fatal(err)
}
if err := compareDirectories(src, dest); err != nil {
t.Fatal(err)
}
}
func TestChrootCopyWithTar(t *testing.T) {
tmpdir, err := ioutil.TempDir("", "docker-TestChrootCopyWithTar")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpdir)
src := filepath.Join(tmpdir, "src")
if err := system.MkdirAll(src, 0700); err != nil {
t.Fatal(err)
}
if _, err := prepareSourceDirectory(10, src, true); err != nil {
t.Fatal(err)
}
// Copy directory
dest := filepath.Join(tmpdir, "dest")
if err := CopyWithTar(src, dest); err != nil {
t.Fatal(err)
}
if err := compareDirectories(src, dest); err != nil {
t.Fatal(err)
}
// Copy file
srcfile := filepath.Join(src, "file-1")
dest = filepath.Join(tmpdir, "destFile")
destfile := filepath.Join(dest, "file-1")
if err := CopyWithTar(srcfile, destfile); err != nil {
t.Fatal(err)
}
if err := compareFiles(srcfile, destfile); err != nil {
t.Fatal(err)
}
// Copy symbolic link
srcLinkfile := filepath.Join(src, "file-1-link")
dest = filepath.Join(tmpdir, "destSymlink")
destLinkfile := filepath.Join(dest, "file-1-link")
if err := CopyWithTar(srcLinkfile, destLinkfile); err != nil {
t.Fatal(err)
}
if err := compareFiles(srcLinkfile, destLinkfile); err != nil {
t.Fatal(err)
}
}
func TestChrootCopyFileWithTar(t *testing.T) {
tmpdir, err := ioutil.TempDir("", "docker-TestChrootCopyFileWithTar")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpdir)
src := filepath.Join(tmpdir, "src")
if err := system.MkdirAll(src, 0700); err != nil {
t.Fatal(err)
}
if _, err := prepareSourceDirectory(10, src, true); err != nil {
t.Fatal(err)
}
// Copy directory
dest := filepath.Join(tmpdir, "dest")
if err := CopyFileWithTar(src, dest); err == nil {
t.Fatal("Expected error on copying directory")
}
// Copy file
srcfile := filepath.Join(src, "file-1")
dest = filepath.Join(tmpdir, "destFile")
destfile := filepath.Join(dest, "file-1")
if err := CopyFileWithTar(srcfile, destfile); err != nil {
t.Fatal(err)
}
if err := compareFiles(srcfile, destfile); err != nil {
t.Fatal(err)
}
// Copy symbolic link
srcLinkfile := filepath.Join(src, "file-1-link")
dest = filepath.Join(tmpdir, "destSymlink")
destLinkfile := filepath.Join(dest, "file-1-link")
if err := CopyFileWithTar(srcLinkfile, destLinkfile); err != nil {
t.Fatal(err)
}
if err := compareFiles(srcLinkfile, destLinkfile); err != nil {
t.Fatal(err)
}
}
func TestChrootUntarPath(t *testing.T) {
tmpdir, err := ioutil.TempDir("", "docker-TestChrootUntarPath")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpdir)
src := filepath.Join(tmpdir, "src")
if err := system.MkdirAll(src, 0700); err != nil {
t.Fatal(err)
}
if _, err := prepareSourceDirectory(10, src, true); err != nil {
t.Fatal(err)
}
dest := filepath.Join(tmpdir, "dest")
// Untar a directory
if err := UntarPath(src, dest); err == nil {
t.Fatal("Expected error on untaring a directory")
}
// Untar a tar file
stream, err := archive.Tar(src, archive.Uncompressed)
if err != nil {
t.Fatal(err)
}
buf := new(bytes.Buffer)
buf.ReadFrom(stream)
tarfile := filepath.Join(tmpdir, "src.tar")
if err := ioutil.WriteFile(tarfile, buf.Bytes(), 0644); err != nil {
t.Fatal(err)
}
if err := UntarPath(tarfile, dest); err != nil {
t.Fatal(err)
}
if err := compareDirectories(src, dest); err != nil {
t.Fatal(err)
}
}
type slowEmptyTarReader struct {
size int
offset int
chunkSize int
}
// Read is a slow reader of an empty tar (like the output of "tar c --files-from /dev/null")
func (s *slowEmptyTarReader) Read(p []byte) (int, error) {
time.Sleep(100 * time.Millisecond)
count := s.chunkSize
if len(p) < s.chunkSize {
count = len(p)
}
for i := 0; i < count; i++ {
p[i] = 0
}
s.offset += count
if s.offset > s.size {
return count, io.EOF
}
return count, nil
}
func TestChrootUntarEmptyArchiveFromSlowReader(t *testing.T) {
tmpdir, err := ioutil.TempDir("", "docker-TestChrootUntarEmptyArchiveFromSlowReader")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpdir)
dest := filepath.Join(tmpdir, "dest")
if err := system.MkdirAll(dest, 0700); err != nil {
t.Fatal(err)
}
stream := &slowEmptyTarReader{size: 10240, chunkSize: 1024}
if err := Untar(stream, dest, nil); err != nil {
t.Fatal(err)
}
}
func TestChrootApplyEmptyArchiveFromSlowReader(t *testing.T) {
tmpdir, err := ioutil.TempDir("", "docker-TestChrootApplyEmptyArchiveFromSlowReader")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpdir)
dest := filepath.Join(tmpdir, "dest")
if err := system.MkdirAll(dest, 0700); err != nil {
t.Fatal(err)
}
stream := &slowEmptyTarReader{size: 10240, chunkSize: 1024}
if _, err := ApplyLayer(dest, stream); err != nil {
t.Fatal(err)
}
}
func TestChrootApplyDotDotFile(t *testing.T) {
tmpdir, err := ioutil.TempDir("", "docker-TestChrootApplyDotDotFile")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpdir)
src := filepath.Join(tmpdir, "src")
if err := system.MkdirAll(src, 0700); err != nil {
t.Fatal(err)
}
if err := ioutil.WriteFile(filepath.Join(src, "..gitme"), []byte(""), 0644); err != nil {
t.Fatal(err)
}
stream, err := archive.Tar(src, archive.Uncompressed)
if err != nil {
t.Fatal(err)
}
dest := filepath.Join(tmpdir, "dest")
if err := system.MkdirAll(dest, 0700); err != nil {
t.Fatal(err)
}
if _, err := ApplyLayer(dest, stream); err != nil {
t.Fatal(err)
}
}

View File

@@ -0,0 +1,91 @@
// +build !windows
package chrootarchive
import (
"bytes"
"encoding/json"
"flag"
"fmt"
"io"
"io/ioutil"
"os"
"runtime"
"syscall"
"github.com/hyperhq/hypercli/pkg/archive"
"github.com/hyperhq/hypercli/pkg/reexec"
)
func chroot(path string) error {
if err := syscall.Chroot(path); err != nil {
return err
}
return syscall.Chdir("/")
}
// untar is the entry-point for docker-untar on re-exec. This is not used on
// Windows as it does not support chroot, hence no point sandboxing through
// chroot and rexec.
func untar() {
runtime.LockOSThread()
flag.Parse()
var options *archive.TarOptions
//read the options from the pipe "ExtraFiles"
if err := json.NewDecoder(os.NewFile(3, "options")).Decode(&options); err != nil {
fatal(err)
}
if err := chroot(flag.Arg(0)); err != nil {
fatal(err)
}
if err := archive.Unpack(os.Stdin, "/", options); err != nil {
fatal(err)
}
// fully consume stdin in case it is zero padded
flush(os.Stdin)
os.Exit(0)
}
func invokeUnpack(decompressedArchive io.Reader, dest string, options *archive.TarOptions) error {
// We can't pass a potentially large exclude list directly via cmd line
// because we easily overrun the kernel's max argument/environment size
// when the full image list is passed (e.g. when this is used by
// `docker load`). We will marshall the options via a pipe to the
// child
r, w, err := os.Pipe()
if err != nil {
return fmt.Errorf("Untar pipe failure: %v", err)
}
cmd := reexec.Command("docker-untar", dest)
cmd.Stdin = decompressedArchive
cmd.ExtraFiles = append(cmd.ExtraFiles, r)
output := bytes.NewBuffer(nil)
cmd.Stdout = output
cmd.Stderr = output
if err := cmd.Start(); err != nil {
return fmt.Errorf("Untar error on re-exec cmd: %v", err)
}
//write the options to the pipe for the untar exec to read
if err := json.NewEncoder(w).Encode(options); err != nil {
return fmt.Errorf("Untar json encode to pipe failed: %v", err)
}
w.Close()
if err := cmd.Wait(); err != nil {
// when `xz -d -c -q | docker-untar ...` failed on docker-untar side,
// we need to exhaust `xz`'s output, otherwise the `xz` side will be
// pending on write pipe forever
io.Copy(ioutil.Discard, decompressedArchive)
return fmt.Errorf("Untar re-exec error: %v: output: %s", err, output)
}
return nil
}

View File

@@ -0,0 +1,22 @@
package chrootarchive
import (
"io"
"github.com/hyperhq/hypercli/pkg/archive"
"github.com/hyperhq/hypercli/pkg/longpath"
)
// chroot is not supported by Windows
func chroot(path string) error {
return nil
}
func invokeUnpack(decompressedArchive io.ReadCloser,
dest string,
options *archive.TarOptions) error {
// Windows is different to Linux here because Windows does not support
// chroot. Hence there is no point sandboxing a chrooted process to
// do the unpack. We call inline instead within the daemon process.
return archive.Unpack(decompressedArchive, longpath.AddPrefix(dest), options)
}

View File

@@ -0,0 +1,19 @@
package chrootarchive
import "github.com/hyperhq/hypercli/pkg/archive"
// ApplyLayer parses a diff in the standard layer format from `layer`,
// and applies it to the directory `dest`. The stream `layer` can only be
// uncompressed.
// Returns the size in bytes of the contents of the layer.
func ApplyLayer(dest string, layer archive.Reader) (size int64, err error) {
return applyLayerHandler(dest, layer, &archive.TarOptions{}, true)
}
// ApplyUncompressedLayer parses a diff in the standard layer format from
// `layer`, and applies it to the directory `dest`. The stream `layer`
// can only be uncompressed.
// Returns the size in bytes of the contents of the layer.
func ApplyUncompressedLayer(dest string, layer archive.Reader, options *archive.TarOptions) (int64, error) {
return applyLayerHandler(dest, layer, options, false)
}

View File

@@ -0,0 +1,118 @@
//+build !windows
package chrootarchive
import (
"bytes"
"encoding/json"
"flag"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"runtime"
"github.com/hyperhq/hypercli/pkg/archive"
"github.com/hyperhq/hypercli/pkg/reexec"
"github.com/hyperhq/hypercli/pkg/system"
)
type applyLayerResponse struct {
LayerSize int64 `json:"layerSize"`
}
// applyLayer is the entry-point for docker-applylayer on re-exec. This is not
// used on Windows as it does not support chroot, hence no point sandboxing
// through chroot and rexec.
func applyLayer() {
var (
tmpDir = ""
err error
options *archive.TarOptions
)
runtime.LockOSThread()
flag.Parse()
if err := chroot(flag.Arg(0)); err != nil {
fatal(err)
}
// We need to be able to set any perms
oldmask, err := system.Umask(0)
defer system.Umask(oldmask)
if err != nil {
fatal(err)
}
if err := json.Unmarshal([]byte(os.Getenv("OPT")), &options); err != nil {
fatal(err)
}
if tmpDir, err = ioutil.TempDir("/", "temp-docker-extract"); err != nil {
fatal(err)
}
os.Setenv("TMPDIR", tmpDir)
size, err := archive.UnpackLayer("/", os.Stdin, options)
os.RemoveAll(tmpDir)
if err != nil {
fatal(err)
}
encoder := json.NewEncoder(os.Stdout)
if err := encoder.Encode(applyLayerResponse{size}); err != nil {
fatal(fmt.Errorf("unable to encode layerSize JSON: %s", err))
}
flush(os.Stdout)
flush(os.Stdin)
os.Exit(0)
}
// applyLayerHandler parses a diff in the standard layer format from `layer`, and
// applies it to the directory `dest`. Returns the size in bytes of the
// contents of the layer.
func applyLayerHandler(dest string, layer archive.Reader, options *archive.TarOptions, decompress bool) (size int64, err error) {
dest = filepath.Clean(dest)
if decompress {
decompressed, err := archive.DecompressStream(layer)
if err != nil {
return 0, err
}
defer decompressed.Close()
layer = decompressed
}
if options == nil {
options = &archive.TarOptions{}
}
if options.ExcludePatterns == nil {
options.ExcludePatterns = []string{}
}
data, err := json.Marshal(options)
if err != nil {
return 0, fmt.Errorf("ApplyLayer json encode: %v", err)
}
cmd := reexec.Command("docker-applyLayer", dest)
cmd.Stdin = layer
cmd.Env = append(cmd.Env, fmt.Sprintf("OPT=%s", data))
outBuf, errBuf := new(bytes.Buffer), new(bytes.Buffer)
cmd.Stdout, cmd.Stderr = outBuf, errBuf
if err = cmd.Run(); err != nil {
return 0, fmt.Errorf("ApplyLayer %s stdout: %s stderr: %s", err, outBuf, errBuf)
}
// Stdout should be a valid JSON struct representing an applyLayerResponse.
response := applyLayerResponse{}
decoder := json.NewDecoder(outBuf)
if err = decoder.Decode(&response); err != nil {
return 0, fmt.Errorf("unable to decode ApplyLayer JSON response: %s", err)
}
return response.LayerSize, nil
}

View File

@@ -0,0 +1,44 @@
package chrootarchive
import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
"github.com/hyperhq/hypercli/pkg/archive"
"github.com/hyperhq/hypercli/pkg/longpath"
)
// applyLayerHandler parses a diff in the standard layer format from `layer`, and
// applies it to the directory `dest`. Returns the size in bytes of the
// contents of the layer.
func applyLayerHandler(dest string, layer archive.Reader, options *archive.TarOptions, decompress bool) (size int64, err error) {
dest = filepath.Clean(dest)
// Ensure it is a Windows-style volume path
dest = longpath.AddPrefix(dest)
if decompress {
decompressed, err := archive.DecompressStream(layer)
if err != nil {
return 0, err
}
defer decompressed.Close()
layer = decompressed
}
tmpDir, err := ioutil.TempDir(os.Getenv("temp"), "temp-docker-extract")
if err != nil {
return 0, fmt.Errorf("ApplyLayer failed to create temp-docker-extract under %s. %s", dest, err)
}
s, err := archive.UnpackLayer(dest, layer, nil)
os.RemoveAll(tmpDir)
if err != nil {
return 0, fmt.Errorf("ApplyLayer %s failed UnpackLayer to %s", err, dest)
}
return s, nil
}

View File

@@ -0,0 +1,28 @@
// +build !windows
package chrootarchive
import (
"fmt"
"io"
"io/ioutil"
"os"
"github.com/hyperhq/hypercli/pkg/reexec"
)
func init() {
reexec.Register("docker-applyLayer", applyLayer)
reexec.Register("docker-untar", untar)
}
func fatal(err error) {
fmt.Fprint(os.Stderr, err)
os.Exit(1)
}
// flush consumes all the bytes from the reader discarding
// any errors
func flush(r io.Reader) {
io.Copy(ioutil.Discard, r)
}

View File

@@ -0,0 +1,4 @@
package chrootarchive
func init() {
}

View File

@@ -0,0 +1,807 @@
// +build linux
package devicemapper
import (
"errors"
"fmt"
"os"
"runtime"
"syscall"
"unsafe"
"github.com/Sirupsen/logrus"
)
// DevmapperLogger defines methods for logging with devicemapper.
type DevmapperLogger interface {
DMLog(level int, file string, line int, dmError int, message string)
}
const (
deviceCreate TaskType = iota
deviceReload
deviceRemove
deviceRemoveAll
deviceSuspend
deviceResume
deviceInfo
deviceDeps
deviceRename
deviceVersion
deviceStatus
deviceTable
deviceWaitevent
deviceList
deviceClear
deviceMknodes
deviceListVersions
deviceTargetMsg
deviceSetGeometry
)
const (
addNodeOnResume AddNodeType = iota
addNodeOnCreate
)
// List of errors returned when using devicemapper.
var (
ErrTaskRun = errors.New("dm_task_run failed")
ErrTaskSetName = errors.New("dm_task_set_name failed")
ErrTaskSetMessage = errors.New("dm_task_set_message failed")
ErrTaskSetAddNode = errors.New("dm_task_set_add_node failed")
ErrTaskSetRo = errors.New("dm_task_set_ro failed")
ErrTaskAddTarget = errors.New("dm_task_add_target failed")
ErrTaskSetSector = errors.New("dm_task_set_sector failed")
ErrTaskGetDeps = errors.New("dm_task_get_deps failed")
ErrTaskGetInfo = errors.New("dm_task_get_info failed")
ErrTaskGetDriverVersion = errors.New("dm_task_get_driver_version failed")
ErrTaskDeferredRemove = errors.New("dm_task_deferred_remove failed")
ErrTaskSetCookie = errors.New("dm_task_set_cookie failed")
ErrNilCookie = errors.New("cookie ptr can't be nil")
ErrGetBlockSize = errors.New("Can't get block size")
ErrUdevWait = errors.New("wait on udev cookie failed")
ErrSetDevDir = errors.New("dm_set_dev_dir failed")
ErrGetLibraryVersion = errors.New("dm_get_library_version failed")
ErrCreateRemoveTask = errors.New("Can't create task of type deviceRemove")
ErrRunRemoveDevice = errors.New("running RemoveDevice failed")
ErrInvalidAddNode = errors.New("Invalid AddNode type")
ErrBusy = errors.New("Device is Busy")
ErrDeviceIDExists = errors.New("Device Id Exists")
ErrEnxio = errors.New("No such device or address")
)
var (
dmSawBusy bool
dmSawExist bool
dmSawEnxio bool // No Such Device or Address
)
type (
// Task represents a devicemapper task (like lvcreate, etc.) ; a task is needed for each ioctl
// command to execute.
Task struct {
unmanaged *cdmTask
}
// Deps represents dependents (layer) of a device.
Deps struct {
Count uint32
Filler uint32
Device []uint64
}
// Info represents information about a device.
Info struct {
Exists int
Suspended int
LiveTable int
InactiveTable int
OpenCount int32
EventNr uint32
Major uint32
Minor uint32
ReadOnly int
TargetCount int32
DeferredRemove int
}
// TaskType represents a type of task
TaskType int
// AddNodeType represents a type of node to be added
AddNodeType int
)
// DeviceIDExists returns whether error conveys the information about device Id already
// exist or not. This will be true if device creation or snap creation
// operation fails if device or snap device already exists in pool.
// Current implementation is little crude as it scans the error string
// for exact pattern match. Replacing it with more robust implementation
// is desirable.
func DeviceIDExists(err error) bool {
return fmt.Sprint(err) == fmt.Sprint(ErrDeviceIDExists)
}
func (t *Task) destroy() {
if t != nil {
DmTaskDestroy(t.unmanaged)
runtime.SetFinalizer(t, nil)
}
}
// TaskCreateNamed is a convenience function for TaskCreate when a name
// will be set on the task as well
func TaskCreateNamed(t TaskType, name string) (*Task, error) {
task := TaskCreate(t)
if task == nil {
return nil, fmt.Errorf("devicemapper: Can't create task of type %d", int(t))
}
if err := task.setName(name); err != nil {
return nil, fmt.Errorf("devicemapper: Can't set task name %s", name)
}
return task, nil
}
// TaskCreate initializes a devicemapper task of tasktype
func TaskCreate(tasktype TaskType) *Task {
Ctask := DmTaskCreate(int(tasktype))
if Ctask == nil {
return nil
}
task := &Task{unmanaged: Ctask}
runtime.SetFinalizer(task, (*Task).destroy)
return task
}
func (t *Task) run() error {
if res := DmTaskRun(t.unmanaged); res != 1 {
return ErrTaskRun
}
return nil
}
func (t *Task) setName(name string) error {
if res := DmTaskSetName(t.unmanaged, name); res != 1 {
return ErrTaskSetName
}
return nil
}
func (t *Task) setMessage(message string) error {
if res := DmTaskSetMessage(t.unmanaged, message); res != 1 {
return ErrTaskSetMessage
}
return nil
}
func (t *Task) setSector(sector uint64) error {
if res := DmTaskSetSector(t.unmanaged, sector); res != 1 {
return ErrTaskSetSector
}
return nil
}
func (t *Task) setCookie(cookie *uint, flags uint16) error {
if cookie == nil {
return ErrNilCookie
}
if res := DmTaskSetCookie(t.unmanaged, cookie, flags); res != 1 {
return ErrTaskSetCookie
}
return nil
}
func (t *Task) setAddNode(addNode AddNodeType) error {
if addNode != addNodeOnResume && addNode != addNodeOnCreate {
return ErrInvalidAddNode
}
if res := DmTaskSetAddNode(t.unmanaged, addNode); res != 1 {
return ErrTaskSetAddNode
}
return nil
}
func (t *Task) setRo() error {
if res := DmTaskSetRo(t.unmanaged); res != 1 {
return ErrTaskSetRo
}
return nil
}
func (t *Task) addTarget(start, size uint64, ttype, params string) error {
if res := DmTaskAddTarget(t.unmanaged, start, size,
ttype, params); res != 1 {
return ErrTaskAddTarget
}
return nil
}
func (t *Task) getDeps() (*Deps, error) {
var deps *Deps
if deps = DmTaskGetDeps(t.unmanaged); deps == nil {
return nil, ErrTaskGetDeps
}
return deps, nil
}
func (t *Task) getInfo() (*Info, error) {
info := &Info{}
if res := DmTaskGetInfo(t.unmanaged, info); res != 1 {
return nil, ErrTaskGetInfo
}
return info, nil
}
func (t *Task) getInfoWithDeferred() (*Info, error) {
info := &Info{}
if res := DmTaskGetInfoWithDeferred(t.unmanaged, info); res != 1 {
return nil, ErrTaskGetInfo
}
return info, nil
}
func (t *Task) getDriverVersion() (string, error) {
res := DmTaskGetDriverVersion(t.unmanaged)
if res == "" {
return "", ErrTaskGetDriverVersion
}
return res, nil
}
func (t *Task) getNextTarget(next unsafe.Pointer) (nextPtr unsafe.Pointer, start uint64,
length uint64, targetType string, params string) {
return DmGetNextTarget(t.unmanaged, next, &start, &length,
&targetType, &params),
start, length, targetType, params
}
// UdevWait waits for any processes that are waiting for udev to complete the specified cookie.
func UdevWait(cookie *uint) error {
if res := DmUdevWait(*cookie); res != 1 {
logrus.Debugf("devicemapper: Failed to wait on udev cookie %d", *cookie)
return ErrUdevWait
}
return nil
}
// LogInitVerbose is an interface to initialize the verbose logger for the device mapper library.
func LogInitVerbose(level int) {
DmLogInitVerbose(level)
}
var dmLogger DevmapperLogger
// LogInit initializes the logger for the device mapper library.
func LogInit(logger DevmapperLogger) {
dmLogger = logger
LogWithErrnoInit()
}
// SetDevDir sets the dev folder for the device mapper library (usually /dev).
func SetDevDir(dir string) error {
if res := DmSetDevDir(dir); res != 1 {
logrus.Debugf("devicemapper: Error dm_set_dev_dir")
return ErrSetDevDir
}
return nil
}
// GetLibraryVersion returns the device mapper library version.
func GetLibraryVersion() (string, error) {
var version string
if res := DmGetLibraryVersion(&version); res != 1 {
return "", ErrGetLibraryVersion
}
return version, nil
}
// UdevSyncSupported returns whether device-mapper is able to sync with udev
//
// This is essential otherwise race conditions can arise where both udev and
// device-mapper attempt to create and destroy devices.
func UdevSyncSupported() bool {
return DmUdevGetSyncSupport() != 0
}
// UdevSetSyncSupport allows setting whether the udev sync should be enabled.
// The return bool indicates the state of whether the sync is enabled.
func UdevSetSyncSupport(enable bool) bool {
if enable {
DmUdevSetSyncSupport(1)
} else {
DmUdevSetSyncSupport(0)
}
return UdevSyncSupported()
}
// CookieSupported returns whether the version of device-mapper supports the
// use of cookie's in the tasks.
// This is largely a lower level call that other functions use.
func CookieSupported() bool {
return DmCookieSupported() != 0
}
// RemoveDevice is a useful helper for cleaning up a device.
func RemoveDevice(name string) error {
task, err := TaskCreateNamed(deviceRemove, name)
if task == nil {
return err
}
var cookie uint
if err := task.setCookie(&cookie, 0); err != nil {
return fmt.Errorf("devicemapper: Can not set cookie: %s", err)
}
defer UdevWait(&cookie)
dmSawBusy = false // reset before the task is run
if err = task.run(); err != nil {
if dmSawBusy {
return ErrBusy
}
return fmt.Errorf("devicemapper: Error running RemoveDevice %s", err)
}
return nil
}
// RemoveDeviceDeferred is a useful helper for cleaning up a device, but deferred.
func RemoveDeviceDeferred(name string) error {
logrus.Debugf("devicemapper: RemoveDeviceDeferred START(%s)", name)
defer logrus.Debugf("devicemapper: RemoveDeviceDeferred END(%s)", name)
task, err := TaskCreateNamed(deviceRemove, name)
if task == nil {
return err
}
if err := DmTaskDeferredRemove(task.unmanaged); err != 1 {
return ErrTaskDeferredRemove
}
if err = task.run(); err != nil {
return fmt.Errorf("devicemapper: Error running RemoveDeviceDeferred %s", err)
}
return nil
}
// CancelDeferredRemove cancels a deferred remove for a device.
func CancelDeferredRemove(deviceName string) error {
task, err := TaskCreateNamed(deviceTargetMsg, deviceName)
if task == nil {
return err
}
if err := task.setSector(0); err != nil {
return fmt.Errorf("devicemapper: Can't set sector %s", err)
}
if err := task.setMessage(fmt.Sprintf("@cancel_deferred_remove")); err != nil {
return fmt.Errorf("devicemapper: Can't set message %s", err)
}
dmSawBusy = false
dmSawEnxio = false
if err := task.run(); err != nil {
// A device might be being deleted already
if dmSawBusy {
return ErrBusy
} else if dmSawEnxio {
return ErrEnxio
}
return fmt.Errorf("devicemapper: Error running CancelDeferredRemove %s", err)
}
return nil
}
// GetBlockDeviceSize returns the size of a block device identified by the specified file.
func GetBlockDeviceSize(file *os.File) (uint64, error) {
size, err := ioctlBlkGetSize64(file.Fd())
if err != nil {
logrus.Errorf("devicemapper: Error getblockdevicesize: %s", err)
return 0, ErrGetBlockSize
}
return uint64(size), nil
}
// BlockDeviceDiscard runs discard for the given path.
// This is used as a workaround for the kernel not discarding block so
// on the thin pool when we remove a thinp device, so we do it
// manually
func BlockDeviceDiscard(path string) error {
file, err := os.OpenFile(path, os.O_RDWR, 0)
if err != nil {
return err
}
defer file.Close()
size, err := GetBlockDeviceSize(file)
if err != nil {
return err
}
if err := ioctlBlkDiscard(file.Fd(), 0, size); err != nil {
return err
}
// Without this sometimes the remove of the device that happens after
// discard fails with EBUSY.
syscall.Sync()
return nil
}
// CreatePool is the programmatic example of "dmsetup create".
// It creates a device with the specified poolName, data and metadata file and block size.
func CreatePool(poolName string, dataFile, metadataFile *os.File, poolBlockSize uint32) error {
task, err := TaskCreateNamed(deviceCreate, poolName)
if task == nil {
return err
}
size, err := GetBlockDeviceSize(dataFile)
if err != nil {
return fmt.Errorf("devicemapper: Can't get data size %s", err)
}
params := fmt.Sprintf("%s %s %d 32768 1 skip_block_zeroing", metadataFile.Name(), dataFile.Name(), poolBlockSize)
if err := task.addTarget(0, size/512, "thin-pool", params); err != nil {
return fmt.Errorf("devicemapper: Can't add target %s", err)
}
var cookie uint
var flags uint16
flags = DmUdevDisableSubsystemRulesFlag | DmUdevDisableDiskRulesFlag | DmUdevDisableOtherRulesFlag
if err := task.setCookie(&cookie, flags); err != nil {
return fmt.Errorf("devicemapper: Can't set cookie %s", err)
}
defer UdevWait(&cookie)
if err := task.run(); err != nil {
return fmt.Errorf("devicemapper: Error running deviceCreate (CreatePool) %s", err)
}
return nil
}
// ReloadPool is the programmatic example of "dmsetup reload".
// It reloads the table with the specified poolName, data and metadata file and block size.
func ReloadPool(poolName string, dataFile, metadataFile *os.File, poolBlockSize uint32) error {
task, err := TaskCreateNamed(deviceReload, poolName)
if task == nil {
return err
}
size, err := GetBlockDeviceSize(dataFile)
if err != nil {
return fmt.Errorf("devicemapper: Can't get data size %s", err)
}
params := fmt.Sprintf("%s %s %d 32768 1 skip_block_zeroing", metadataFile.Name(), dataFile.Name(), poolBlockSize)
if err := task.addTarget(0, size/512, "thin-pool", params); err != nil {
return fmt.Errorf("devicemapper: Can't add target %s", err)
}
if err := task.run(); err != nil {
return fmt.Errorf("devicemapper: Error running deviceCreate %s", err)
}
return nil
}
// GetDeps is the programmatic example of "dmsetup deps".
// It outputs a list of devices referenced by the live table for the specified device.
func GetDeps(name string) (*Deps, error) {
task, err := TaskCreateNamed(deviceDeps, name)
if task == nil {
return nil, err
}
if err := task.run(); err != nil {
return nil, err
}
return task.getDeps()
}
// GetInfo is the programmatic example of "dmsetup info".
// It outputs some brief information about the device.
func GetInfo(name string) (*Info, error) {
task, err := TaskCreateNamed(deviceInfo, name)
if task == nil {
return nil, err
}
if err := task.run(); err != nil {
return nil, err
}
return task.getInfo()
}
// GetInfoWithDeferred is the programmatic example of "dmsetup info", but deferred.
// It outputs some brief information about the device.
func GetInfoWithDeferred(name string) (*Info, error) {
task, err := TaskCreateNamed(deviceInfo, name)
if task == nil {
return nil, err
}
if err := task.run(); err != nil {
return nil, err
}
return task.getInfoWithDeferred()
}
// GetDriverVersion is the programmatic example of "dmsetup version".
// It outputs version information of the driver.
func GetDriverVersion() (string, error) {
task := TaskCreate(deviceVersion)
if task == nil {
return "", fmt.Errorf("devicemapper: Can't create deviceVersion task")
}
if err := task.run(); err != nil {
return "", err
}
return task.getDriverVersion()
}
// GetStatus is the programmatic example of "dmsetup status".
// It outputs status information for the specified device name.
func GetStatus(name string) (uint64, uint64, string, string, error) {
task, err := TaskCreateNamed(deviceStatus, name)
if task == nil {
logrus.Debugf("devicemapper: GetStatus() Error TaskCreateNamed: %s", err)
return 0, 0, "", "", err
}
if err := task.run(); err != nil {
logrus.Debugf("devicemapper: GetStatus() Error Run: %s", err)
return 0, 0, "", "", err
}
devinfo, err := task.getInfo()
if err != nil {
logrus.Debugf("devicemapper: GetStatus() Error GetInfo: %s", err)
return 0, 0, "", "", err
}
if devinfo.Exists == 0 {
logrus.Debugf("devicemapper: GetStatus() Non existing device %s", name)
return 0, 0, "", "", fmt.Errorf("devicemapper: Non existing device %s", name)
}
_, start, length, targetType, params := task.getNextTarget(unsafe.Pointer(nil))
return start, length, targetType, params, nil
}
// GetTable is the programmatic example for "dmsetup table".
// It outputs the current table for the specified device name.
func GetTable(name string) (uint64, uint64, string, string, error) {
task, err := TaskCreateNamed(deviceTable, name)
if task == nil {
logrus.Debugf("devicemapper: GetTable() Error TaskCreateNamed: %s", err)
return 0, 0, "", "", err
}
if err := task.run(); err != nil {
logrus.Debugf("devicemapper: GetTable() Error Run: %s", err)
return 0, 0, "", "", err
}
devinfo, err := task.getInfo()
if err != nil {
logrus.Debugf("devicemapper: GetTable() Error GetInfo: %s", err)
return 0, 0, "", "", err
}
if devinfo.Exists == 0 {
logrus.Debugf("devicemapper: GetTable() Non existing device %s", name)
return 0, 0, "", "", fmt.Errorf("devicemapper: Non existing device %s", name)
}
_, start, length, targetType, params := task.getNextTarget(unsafe.Pointer(nil))
return start, length, targetType, params, nil
}
// SetTransactionID sets a transaction id for the specified device name.
func SetTransactionID(poolName string, oldID uint64, newID uint64) error {
task, err := TaskCreateNamed(deviceTargetMsg, poolName)
if task == nil {
return err
}
if err := task.setSector(0); err != nil {
return fmt.Errorf("devicemapper: Can't set sector %s", err)
}
if err := task.setMessage(fmt.Sprintf("set_transaction_id %d %d", oldID, newID)); err != nil {
return fmt.Errorf("devicemapper: Can't set message %s", err)
}
if err := task.run(); err != nil {
return fmt.Errorf("devicemapper: Error running SetTransactionID %s", err)
}
return nil
}
// SuspendDevice is the programmatic example of "dmsetup suspend".
// It suspends the specified device.
func SuspendDevice(name string) error {
task, err := TaskCreateNamed(deviceSuspend, name)
if task == nil {
return err
}
if err := task.run(); err != nil {
return fmt.Errorf("devicemapper: Error running deviceSuspend %s", err)
}
return nil
}
// ResumeDevice is the programmatic example of "dmsetup resume".
// It un-suspends the specified device.
func ResumeDevice(name string) error {
task, err := TaskCreateNamed(deviceResume, name)
if task == nil {
return err
}
var cookie uint
if err := task.setCookie(&cookie, 0); err != nil {
return fmt.Errorf("devicemapper: Can't set cookie %s", err)
}
defer UdevWait(&cookie)
if err := task.run(); err != nil {
return fmt.Errorf("devicemapper: Error running deviceResume %s", err)
}
return nil
}
// CreateDevice creates a device with the specified poolName with the specified device id.
func CreateDevice(poolName string, deviceID int) error {
logrus.Debugf("devicemapper: CreateDevice(poolName=%v, deviceID=%v)", poolName, deviceID)
task, err := TaskCreateNamed(deviceTargetMsg, poolName)
if task == nil {
return err
}
if err := task.setSector(0); err != nil {
return fmt.Errorf("devicemapper: Can't set sector %s", err)
}
if err := task.setMessage(fmt.Sprintf("create_thin %d", deviceID)); err != nil {
return fmt.Errorf("devicemapper: Can't set message %s", err)
}
dmSawExist = false // reset before the task is run
if err := task.run(); err != nil {
// Caller wants to know about ErrDeviceIDExists so that it can try with a different device id.
if dmSawExist {
return ErrDeviceIDExists
}
return fmt.Errorf("devicemapper: Error running CreateDevice %s", err)
}
return nil
}
// DeleteDevice deletes a device with the specified poolName with the specified device id.
func DeleteDevice(poolName string, deviceID int) error {
task, err := TaskCreateNamed(deviceTargetMsg, poolName)
if task == nil {
return err
}
if err := task.setSector(0); err != nil {
return fmt.Errorf("devicemapper: Can't set sector %s", err)
}
if err := task.setMessage(fmt.Sprintf("delete %d", deviceID)); err != nil {
return fmt.Errorf("devicemapper: Can't set message %s", err)
}
dmSawBusy = false
if err := task.run(); err != nil {
if dmSawBusy {
return ErrBusy
}
return fmt.Errorf("devicemapper: Error running DeleteDevice %s", err)
}
return nil
}
// ActivateDevice activates the device identified by the specified
// poolName, name and deviceID with the specified size.
func ActivateDevice(poolName string, name string, deviceID int, size uint64) error {
return activateDevice(poolName, name, deviceID, size, "")
}
// ActivateDeviceWithExternal activates the device identified by the specified
// poolName, name and deviceID with the specified size.
func ActivateDeviceWithExternal(poolName string, name string, deviceID int, size uint64, external string) error {
return activateDevice(poolName, name, deviceID, size, external)
}
func activateDevice(poolName string, name string, deviceID int, size uint64, external string) error {
task, err := TaskCreateNamed(deviceCreate, name)
if task == nil {
return err
}
var params string
if len(external) > 0 {
params = fmt.Sprintf("%s %d %s", poolName, deviceID, external)
} else {
params = fmt.Sprintf("%s %d", poolName, deviceID)
}
if err := task.addTarget(0, size/512, "thin", params); err != nil {
return fmt.Errorf("devicemapper: Can't add target %s", err)
}
if err := task.setAddNode(addNodeOnCreate); err != nil {
return fmt.Errorf("devicemapper: Can't add node %s", err)
}
var cookie uint
if err := task.setCookie(&cookie, 0); err != nil {
return fmt.Errorf("devicemapper: Can't set cookie %s", err)
}
defer UdevWait(&cookie)
if err := task.run(); err != nil {
return fmt.Errorf("devicemapper: Error running deviceCreate (ActivateDevice) %s", err)
}
return nil
}
// CreateSnapDevice creates a snapshot based on the device identified by the baseName and baseDeviceId,
func CreateSnapDevice(poolName string, deviceID int, baseName string, baseDeviceID int) error {
devinfo, _ := GetInfo(baseName)
doSuspend := devinfo != nil && devinfo.Exists != 0
if doSuspend {
if err := SuspendDevice(baseName); err != nil {
return err
}
}
task, err := TaskCreateNamed(deviceTargetMsg, poolName)
if task == nil {
if doSuspend {
ResumeDevice(baseName)
}
return err
}
if err := task.setSector(0); err != nil {
if doSuspend {
ResumeDevice(baseName)
}
return fmt.Errorf("devicemapper: Can't set sector %s", err)
}
if err := task.setMessage(fmt.Sprintf("create_snap %d %d", deviceID, baseDeviceID)); err != nil {
if doSuspend {
ResumeDevice(baseName)
}
return fmt.Errorf("devicemapper: Can't set message %s", err)
}
dmSawExist = false // reset before the task is run
if err := task.run(); err != nil {
if doSuspend {
ResumeDevice(baseName)
}
// Caller wants to know about ErrDeviceIDExists so that it can try with a different device id.
if dmSawExist {
return ErrDeviceIDExists
}
return fmt.Errorf("devicemapper: Error running deviceCreate (createSnapDevice) %s", err)
}
if doSuspend {
if err := ResumeDevice(baseName); err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,35 @@
// +build linux
package devicemapper
import "C"
import (
"strings"
)
// Due to the way cgo works this has to be in a separate file, as devmapper.go has
// definitions in the cgo block, which is incompatible with using "//export"
// DevmapperLogCallback exports the devmapper log callback for cgo.
//export DevmapperLogCallback
func DevmapperLogCallback(level C.int, file *C.char, line C.int, dmErrnoOrClass C.int, message *C.char) {
msg := C.GoString(message)
if level < 7 {
if strings.Contains(msg, "busy") {
dmSawBusy = true
}
if strings.Contains(msg, "File exists") {
dmSawExist = true
}
if strings.Contains(msg, "No such device or address") {
dmSawEnxio = true
}
}
if dmLogger != nil {
dmLogger.DMLog(int(level), C.GoString(file), int(line), int(dmErrnoOrClass), msg)
}
}

View File

@@ -0,0 +1,251 @@
// +build linux
package devicemapper
/*
#cgo LDFLAGS: -L. -ldevmapper
#include <libdevmapper.h>
#include <linux/fs.h> // FIXME: present only for BLKGETSIZE64, maybe we can remove it?
// FIXME: Can't we find a way to do the logging in pure Go?
extern void DevmapperLogCallback(int level, char *file, int line, int dm_errno_or_class, char *str);
static void log_cb(int level, const char *file, int line, int dm_errno_or_class, const char *f, ...)
{
char buffer[256];
va_list ap;
va_start(ap, f);
vsnprintf(buffer, 256, f, ap);
va_end(ap);
DevmapperLogCallback(level, (char *)file, line, dm_errno_or_class, buffer);
}
static void log_with_errno_init()
{
dm_log_with_errno_init(log_cb);
}
*/
import "C"
import (
"reflect"
"unsafe"
)
type (
cdmTask C.struct_dm_task
)
// IOCTL consts
const (
BlkGetSize64 = C.BLKGETSIZE64
BlkDiscard = C.BLKDISCARD
)
// Devicemapper cookie flags.
const (
DmUdevDisableSubsystemRulesFlag = C.DM_UDEV_DISABLE_SUBSYSTEM_RULES_FLAG
DmUdevDisableDiskRulesFlag = C.DM_UDEV_DISABLE_DISK_RULES_FLAG
DmUdevDisableOtherRulesFlag = C.DM_UDEV_DISABLE_OTHER_RULES_FLAG
DmUdevDisableLibraryFallback = C.DM_UDEV_DISABLE_LIBRARY_FALLBACK
)
// DeviceMapper mapped functions.
var (
DmGetLibraryVersion = dmGetLibraryVersionFct
DmGetNextTarget = dmGetNextTargetFct
DmLogInitVerbose = dmLogInitVerboseFct
DmSetDevDir = dmSetDevDirFct
DmTaskAddTarget = dmTaskAddTargetFct
DmTaskCreate = dmTaskCreateFct
DmTaskDestroy = dmTaskDestroyFct
DmTaskGetDeps = dmTaskGetDepsFct
DmTaskGetInfo = dmTaskGetInfoFct
DmTaskGetDriverVersion = dmTaskGetDriverVersionFct
DmTaskRun = dmTaskRunFct
DmTaskSetAddNode = dmTaskSetAddNodeFct
DmTaskSetCookie = dmTaskSetCookieFct
DmTaskSetMessage = dmTaskSetMessageFct
DmTaskSetName = dmTaskSetNameFct
DmTaskSetRo = dmTaskSetRoFct
DmTaskSetSector = dmTaskSetSectorFct
DmUdevWait = dmUdevWaitFct
DmUdevSetSyncSupport = dmUdevSetSyncSupportFct
DmUdevGetSyncSupport = dmUdevGetSyncSupportFct
DmCookieSupported = dmCookieSupportedFct
LogWithErrnoInit = logWithErrnoInitFct
DmTaskDeferredRemove = dmTaskDeferredRemoveFct
DmTaskGetInfoWithDeferred = dmTaskGetInfoWithDeferredFct
)
func free(p *C.char) {
C.free(unsafe.Pointer(p))
}
func dmTaskDestroyFct(task *cdmTask) {
C.dm_task_destroy((*C.struct_dm_task)(task))
}
func dmTaskCreateFct(taskType int) *cdmTask {
return (*cdmTask)(C.dm_task_create(C.int(taskType)))
}
func dmTaskRunFct(task *cdmTask) int {
ret, _ := C.dm_task_run((*C.struct_dm_task)(task))
return int(ret)
}
func dmTaskSetNameFct(task *cdmTask, name string) int {
Cname := C.CString(name)
defer free(Cname)
return int(C.dm_task_set_name((*C.struct_dm_task)(task), Cname))
}
func dmTaskSetMessageFct(task *cdmTask, message string) int {
Cmessage := C.CString(message)
defer free(Cmessage)
return int(C.dm_task_set_message((*C.struct_dm_task)(task), Cmessage))
}
func dmTaskSetSectorFct(task *cdmTask, sector uint64) int {
return int(C.dm_task_set_sector((*C.struct_dm_task)(task), C.uint64_t(sector)))
}
func dmTaskSetCookieFct(task *cdmTask, cookie *uint, flags uint16) int {
cCookie := C.uint32_t(*cookie)
defer func() {
*cookie = uint(cCookie)
}()
return int(C.dm_task_set_cookie((*C.struct_dm_task)(task), &cCookie, C.uint16_t(flags)))
}
func dmTaskSetAddNodeFct(task *cdmTask, addNode AddNodeType) int {
return int(C.dm_task_set_add_node((*C.struct_dm_task)(task), C.dm_add_node_t(addNode)))
}
func dmTaskSetRoFct(task *cdmTask) int {
return int(C.dm_task_set_ro((*C.struct_dm_task)(task)))
}
func dmTaskAddTargetFct(task *cdmTask,
start, size uint64, ttype, params string) int {
Cttype := C.CString(ttype)
defer free(Cttype)
Cparams := C.CString(params)
defer free(Cparams)
return int(C.dm_task_add_target((*C.struct_dm_task)(task), C.uint64_t(start), C.uint64_t(size), Cttype, Cparams))
}
func dmTaskGetDepsFct(task *cdmTask) *Deps {
Cdeps := C.dm_task_get_deps((*C.struct_dm_task)(task))
if Cdeps == nil {
return nil
}
// golang issue: https://github.com/golang/go/issues/11925
hdr := reflect.SliceHeader{
Data: uintptr(unsafe.Pointer(uintptr(unsafe.Pointer(Cdeps)) + unsafe.Sizeof(*Cdeps))),
Len: int(Cdeps.count),
Cap: int(Cdeps.count),
}
devices := *(*[]C.uint64_t)(unsafe.Pointer(&hdr))
deps := &Deps{
Count: uint32(Cdeps.count),
Filler: uint32(Cdeps.filler),
}
for _, device := range devices {
deps.Device = append(deps.Device, uint64(device))
}
return deps
}
func dmTaskGetInfoFct(task *cdmTask, info *Info) int {
Cinfo := C.struct_dm_info{}
defer func() {
info.Exists = int(Cinfo.exists)
info.Suspended = int(Cinfo.suspended)
info.LiveTable = int(Cinfo.live_table)
info.InactiveTable = int(Cinfo.inactive_table)
info.OpenCount = int32(Cinfo.open_count)
info.EventNr = uint32(Cinfo.event_nr)
info.Major = uint32(Cinfo.major)
info.Minor = uint32(Cinfo.minor)
info.ReadOnly = int(Cinfo.read_only)
info.TargetCount = int32(Cinfo.target_count)
}()
return int(C.dm_task_get_info((*C.struct_dm_task)(task), &Cinfo))
}
func dmTaskGetDriverVersionFct(task *cdmTask) string {
buffer := C.malloc(128)
defer C.free(buffer)
res := C.dm_task_get_driver_version((*C.struct_dm_task)(task), (*C.char)(buffer), 128)
if res == 0 {
return ""
}
return C.GoString((*C.char)(buffer))
}
func dmGetNextTargetFct(task *cdmTask, next unsafe.Pointer, start, length *uint64, target, params *string) unsafe.Pointer {
var (
Cstart, Clength C.uint64_t
CtargetType, Cparams *C.char
)
defer func() {
*start = uint64(Cstart)
*length = uint64(Clength)
*target = C.GoString(CtargetType)
*params = C.GoString(Cparams)
}()
nextp := C.dm_get_next_target((*C.struct_dm_task)(task), next, &Cstart, &Clength, &CtargetType, &Cparams)
return nextp
}
func dmUdevSetSyncSupportFct(syncWithUdev int) {
(C.dm_udev_set_sync_support(C.int(syncWithUdev)))
}
func dmUdevGetSyncSupportFct() int {
return int(C.dm_udev_get_sync_support())
}
func dmUdevWaitFct(cookie uint) int {
return int(C.dm_udev_wait(C.uint32_t(cookie)))
}
func dmCookieSupportedFct() int {
return int(C.dm_cookie_supported())
}
func dmLogInitVerboseFct(level int) {
C.dm_log_init_verbose(C.int(level))
}
func logWithErrnoInitFct() {
C.log_with_errno_init()
}
func dmSetDevDirFct(dir string) int {
Cdir := C.CString(dir)
defer free(Cdir)
return int(C.dm_set_dev_dir(Cdir))
}
func dmGetLibraryVersionFct(version *string) int {
buffer := C.CString(string(make([]byte, 128)))
defer free(buffer)
defer func() {
*version = C.GoString(buffer)
}()
return int(C.dm_get_library_version(buffer, 128))
}

View File

@@ -0,0 +1,34 @@
// +build linux,!libdm_no_deferred_remove
package devicemapper
/*
#cgo LDFLAGS: -L. -ldevmapper
#include <libdevmapper.h>
*/
import "C"
// LibraryDeferredRemovalSupport is supported when statically linked.
const LibraryDeferredRemovalSupport = true
func dmTaskDeferredRemoveFct(task *cdmTask) int {
return int(C.dm_task_deferred_remove((*C.struct_dm_task)(task)))
}
func dmTaskGetInfoWithDeferredFct(task *cdmTask, info *Info) int {
Cinfo := C.struct_dm_info{}
defer func() {
info.Exists = int(Cinfo.exists)
info.Suspended = int(Cinfo.suspended)
info.LiveTable = int(Cinfo.live_table)
info.InactiveTable = int(Cinfo.inactive_table)
info.OpenCount = int32(Cinfo.open_count)
info.EventNr = uint32(Cinfo.event_nr)
info.Major = uint32(Cinfo.major)
info.Minor = uint32(Cinfo.minor)
info.ReadOnly = int(Cinfo.read_only)
info.TargetCount = int32(Cinfo.target_count)
info.DeferredRemove = int(Cinfo.deferred_remove)
}()
return int(C.dm_task_get_info((*C.struct_dm_task)(task), &Cinfo))
}

View File

@@ -0,0 +1,15 @@
// +build linux,libdm_no_deferred_remove
package devicemapper
// LibraryDeferredRemovalsupport is not supported when statically linked.
const LibraryDeferredRemovalSupport = false
func dmTaskDeferredRemoveFct(task *cdmTask) int {
// Error. Nobody should be calling it.
return -1
}
func dmTaskGetInfoWithDeferredFct(task *cdmTask, info *Info) int {
return -1
}

View File

@@ -0,0 +1,27 @@
// +build linux
package devicemapper
import (
"syscall"
"unsafe"
)
func ioctlBlkGetSize64(fd uintptr) (int64, error) {
var size int64
if _, _, err := syscall.Syscall(syscall.SYS_IOCTL, fd, BlkGetSize64, uintptr(unsafe.Pointer(&size))); err != 0 {
return 0, err
}
return size, nil
}
func ioctlBlkDiscard(fd uintptr, offset, length uint64) error {
var r [2]uint64
r[0] = offset
r[1] = length
if _, _, err := syscall.Syscall(syscall.SYS_IOCTL, fd, BlkDiscard, uintptr(unsafe.Pointer(&r[0]))); err != 0 {
return err
}
return nil
}

View File

@@ -0,0 +1,11 @@
package devicemapper
// definitions from lvm2 lib/log/log.h
const (
LogLevelFatal = 2 + iota // _LOG_FATAL
LogLevelErr // _LOG_ERR
LogLevelWarn // _LOG_WARN
LogLevelNotice // _LOG_NOTICE
LogLevelInfo // _LOG_INFO
LogLevelDebug // _LOG_DEBUG
)

View File

@@ -0,0 +1,26 @@
package directory
import (
"io/ioutil"
"os"
"path/filepath"
)
// MoveToSubdir moves all contents of a directory to a subdirectory underneath the original path
func MoveToSubdir(oldpath, subdir string) error {
infos, err := ioutil.ReadDir(oldpath)
if err != nil {
return err
}
for _, info := range infos {
if info.Name() != subdir {
oldName := filepath.Join(oldpath, info.Name())
newName := filepath.Join(oldpath, subdir, info.Name())
if err := os.Rename(oldName, newName); err != nil {
return err
}
}
}
return nil
}

View File

@@ -0,0 +1,182 @@
package directory
import (
"io/ioutil"
"os"
"path/filepath"
"reflect"
"sort"
"testing"
)
// Size of an empty directory should be 0
func TestSizeEmpty(t *testing.T) {
var dir string
var err error
if dir, err = ioutil.TempDir(os.TempDir(), "testSizeEmptyDirectory"); err != nil {
t.Fatalf("failed to create directory: %s", err)
}
var size int64
if size, _ = Size(dir); size != 0 {
t.Fatalf("empty directory has size: %d", size)
}
}
// Size of a directory with one empty file should be 0
func TestSizeEmptyFile(t *testing.T) {
var dir string
var err error
if dir, err = ioutil.TempDir(os.TempDir(), "testSizeEmptyFile"); err != nil {
t.Fatalf("failed to create directory: %s", err)
}
var file *os.File
if file, err = ioutil.TempFile(dir, "file"); err != nil {
t.Fatalf("failed to create file: %s", err)
}
var size int64
if size, _ = Size(file.Name()); size != 0 {
t.Fatalf("directory with one file has size: %d", size)
}
}
// Size of a directory with one 5-byte file should be 5
func TestSizeNonemptyFile(t *testing.T) {
var dir string
var err error
if dir, err = ioutil.TempDir(os.TempDir(), "testSizeNonemptyFile"); err != nil {
t.Fatalf("failed to create directory: %s", err)
}
var file *os.File
if file, err = ioutil.TempFile(dir, "file"); err != nil {
t.Fatalf("failed to create file: %s", err)
}
d := []byte{97, 98, 99, 100, 101}
file.Write(d)
var size int64
if size, _ = Size(file.Name()); size != 5 {
t.Fatalf("directory with one 5-byte file has size: %d", size)
}
}
// Size of a directory with one empty directory should be 0
func TestSizeNestedDirectoryEmpty(t *testing.T) {
var dir string
var err error
if dir, err = ioutil.TempDir(os.TempDir(), "testSizeNestedDirectoryEmpty"); err != nil {
t.Fatalf("failed to create directory: %s", err)
}
if dir, err = ioutil.TempDir(dir, "nested"); err != nil {
t.Fatalf("failed to create nested directory: %s", err)
}
var size int64
if size, _ = Size(dir); size != 0 {
t.Fatalf("directory with one empty directory has size: %d", size)
}
}
// Test directory with 1 file and 1 empty directory
func TestSizeFileAndNestedDirectoryEmpty(t *testing.T) {
var dir string
var err error
if dir, err = ioutil.TempDir(os.TempDir(), "testSizeFileAndNestedDirectoryEmpty"); err != nil {
t.Fatalf("failed to create directory: %s", err)
}
if dir, err = ioutil.TempDir(dir, "nested"); err != nil {
t.Fatalf("failed to create nested directory: %s", err)
}
var file *os.File
if file, err = ioutil.TempFile(dir, "file"); err != nil {
t.Fatalf("failed to create file: %s", err)
}
d := []byte{100, 111, 99, 107, 101, 114}
file.Write(d)
var size int64
if size, _ = Size(dir); size != 6 {
t.Fatalf("directory with 6-byte file and empty directory has size: %d", size)
}
}
// Test directory with 1 file and 1 non-empty directory
func TestSizeFileAndNestedDirectoryNonempty(t *testing.T) {
var dir, dirNested string
var err error
if dir, err = ioutil.TempDir(os.TempDir(), "TestSizeFileAndNestedDirectoryNonempty"); err != nil {
t.Fatalf("failed to create directory: %s", err)
}
if dirNested, err = ioutil.TempDir(dir, "nested"); err != nil {
t.Fatalf("failed to create nested directory: %s", err)
}
var file *os.File
if file, err = ioutil.TempFile(dir, "file"); err != nil {
t.Fatalf("failed to create file: %s", err)
}
data := []byte{100, 111, 99, 107, 101, 114}
file.Write(data)
var nestedFile *os.File
if nestedFile, err = ioutil.TempFile(dirNested, "file"); err != nil {
t.Fatalf("failed to create file in nested directory: %s", err)
}
nestedData := []byte{100, 111, 99, 107, 101, 114}
nestedFile.Write(nestedData)
var size int64
if size, _ = Size(dir); size != 12 {
t.Fatalf("directory with 6-byte file and nested directory with 6-byte file has size: %d", size)
}
}
// Test migration of directory to a subdir underneath itself
func TestMoveToSubdir(t *testing.T) {
var outerDir, subDir string
var err error
if outerDir, err = ioutil.TempDir(os.TempDir(), "TestMoveToSubdir"); err != nil {
t.Fatalf("failed to create directory: %v", err)
}
if subDir, err = ioutil.TempDir(outerDir, "testSub"); err != nil {
t.Fatalf("failed to create subdirectory: %v", err)
}
// write 4 temp files in the outer dir to get moved
filesList := []string{"a", "b", "c", "d"}
for _, fName := range filesList {
if file, err := os.Create(filepath.Join(outerDir, fName)); err != nil {
t.Fatalf("couldn't create temp file %q: %v", fName, err)
} else {
file.WriteString(fName)
file.Close()
}
}
if err = MoveToSubdir(outerDir, filepath.Base(subDir)); err != nil {
t.Fatalf("Error during migration of content to subdirectory: %v", err)
}
// validate that the files were moved to the subdirectory
infos, err := ioutil.ReadDir(subDir)
if len(infos) != 4 {
t.Fatalf("Should be four files in the subdir after the migration: actual length: %d", len(infos))
}
var results []string
for _, info := range infos {
results = append(results, info.Name())
}
sort.Sort(sort.StringSlice(results))
if !reflect.DeepEqual(filesList, results) {
t.Fatalf("Results after migration do not equal list of files: expected: %v, got: %v", filesList, results)
}
}

View File

@@ -0,0 +1,39 @@
// +build linux freebsd
package directory
import (
"os"
"path/filepath"
"syscall"
)
// Size walks a directory tree and returns its total size in bytes.
func Size(dir string) (size int64, err error) {
data := make(map[uint64]struct{})
err = filepath.Walk(dir, func(d string, fileInfo os.FileInfo, e error) error {
// Ignore directory sizes
if fileInfo == nil {
return nil
}
s := fileInfo.Size()
if fileInfo.IsDir() || s == 0 {
return nil
}
// Check inode to handle hard links correctly
inode := fileInfo.Sys().(*syscall.Stat_t).Ino
// inode is not a uint64 on all platforms. Cast it to avoid issues.
if _, exists := data[uint64(inode)]; exists {
return nil
}
// inode is not a uint64 on all platforms. Cast it to avoid issues.
data[uint64(inode)] = struct{}{}
size += s
return nil
})
return
}

View File

@@ -0,0 +1,35 @@
// +build windows
package directory
import (
"os"
"path/filepath"
"github.com/hyperhq/hypercli/pkg/longpath"
)
// Size walks a directory tree and returns its total size in bytes.
func Size(dir string) (size int64, err error) {
fixedPath, err := filepath.Abs(dir)
if err != nil {
return
}
fixedPath = longpath.AddPrefix(fixedPath)
err = filepath.Walk(dir, func(d string, fileInfo os.FileInfo, e error) error {
// Ignore directory sizes
if fileInfo == nil {
return nil
}
s := fileInfo.Size()
if fileInfo.IsDir() || s == 0 {
return nil
}
size += s
return nil
})
return
}

View File

@@ -0,0 +1,41 @@
---
page_title: Docker discovery
page_description: discovery
page_keywords: docker, clustering, discovery
---
# Discovery
Docker comes with multiple Discovery backends.
## Backends
### Using etcd
Point your Docker Engine instances to a common etcd instance. You can specify
the address Docker uses to advertise the node using the `--cluster-advertise`
flag.
```bash
$ docker daemon -H=<node_ip:2376> --cluster-advertise=<node_ip:2376> --cluster-store etcd://<etcd_ip1>,<etcd_ip2>/<path>
```
### Using consul
Point your Docker Engine instances to a common Consul instance. You can specify
the address Docker uses to advertise the node using the `--cluster-advertise`
flag.
```bash
$ docker daemon -H=<node_ip:2376> --cluster-advertise=<node_ip:2376> --cluster-store consul://<consul_ip>/<path>
```
### Using zookeeper
Point your Docker Engine instances to a common Zookeeper instance. You can specify
the address Docker uses to advertise the node using the `--cluster-advertise`
flag.
```bash
$ docker daemon -H=<node_ip:2376> --cluster-advertise=<node_ip:2376> --cluster-store zk://<zk_addr1>,<zk_addr2>/<path>
```

View File

@@ -0,0 +1,107 @@
package discovery
import (
"fmt"
"net"
"strings"
"time"
log "github.com/Sirupsen/logrus"
)
var (
// Backends is a global map of discovery backends indexed by their
// associated scheme.
backends = make(map[string]Backend)
)
// Register makes a discovery backend available by the provided scheme.
// If Register is called twice with the same scheme an error is returned.
func Register(scheme string, d Backend) error {
if _, exists := backends[scheme]; exists {
return fmt.Errorf("scheme already registered %s", scheme)
}
log.WithField("name", scheme).Debug("Registering discovery service")
backends[scheme] = d
return nil
}
func parse(rawurl string) (string, string) {
parts := strings.SplitN(rawurl, "://", 2)
// nodes:port,node2:port => nodes://node1:port,node2:port
if len(parts) == 1 {
return "nodes", parts[0]
}
return parts[0], parts[1]
}
// ParseAdvertise parses the --cluster-advertise daemon config which accepts
// <ip-address>:<port> or <interface-name>:<port>
func ParseAdvertise(advertise string) (string, error) {
var (
iface *net.Interface
addrs []net.Addr
err error
)
addr, port, err := net.SplitHostPort(advertise)
if err != nil {
return "", fmt.Errorf("invalid --cluster-advertise configuration: %s: %v", advertise, err)
}
ip := net.ParseIP(addr)
// If it is a valid ip-address, use it as is
if ip != nil {
return advertise, nil
}
// If advertise is a valid interface name, get the valid ipv4 address and use it to advertise
ifaceName := addr
iface, err = net.InterfaceByName(ifaceName)
if err != nil {
return "", fmt.Errorf("invalid cluster advertise IP address or interface name (%s) : %v", advertise, err)
}
addrs, err = iface.Addrs()
if err != nil {
return "", fmt.Errorf("unable to get advertise IP address from interface (%s) : %v", advertise, err)
}
if addrs == nil || len(addrs) == 0 {
return "", fmt.Errorf("no available advertise IP address in interface (%s)", advertise)
}
addr = ""
for _, a := range addrs {
ip, _, err := net.ParseCIDR(a.String())
if err != nil {
return "", fmt.Errorf("error deriving advertise ip-address in interface (%s) : %v", advertise, err)
}
if ip.To4() == nil || ip.IsLoopback() {
continue
}
addr = ip.String()
break
}
if addr == "" {
return "", fmt.Errorf("couldnt find a valid ip-address in interface %s", advertise)
}
addr = fmt.Sprintf("%s:%s", addr, port)
return addr, nil
}
// New returns a new Discovery given a URL, heartbeat and ttl settings.
// Returns an error if the URL scheme is not supported.
func New(rawurl string, heartbeat time.Duration, ttl time.Duration, clusterOpts map[string]string) (Backend, error) {
scheme, uri := parse(rawurl)
if backend, exists := backends[scheme]; exists {
log.WithFields(log.Fields{"name": scheme, "uri": uri}).Debug("Initializing discovery service")
err := backend.Initialize(uri, heartbeat, ttl, clusterOpts)
return backend, err
}
return nil, ErrNotSupported
}

View File

@@ -0,0 +1,35 @@
package discovery
import (
"errors"
"time"
)
var (
// ErrNotSupported is returned when a discovery service is not supported.
ErrNotSupported = errors.New("discovery service not supported")
// ErrNotImplemented is returned when discovery feature is not implemented
// by discovery backend.
ErrNotImplemented = errors.New("not implemented in this discovery service")
)
// Watcher provides watching over a cluster for nodes joining and leaving.
type Watcher interface {
// Watch the discovery for entry changes.
// Returns a channel that will receive changes or an error.
// Providing a non-nil stopCh can be used to stop watching.
Watch(stopCh <-chan struct{}) (<-chan Entries, <-chan error)
}
// Backend is implemented by discovery backends which manage cluster entries.
type Backend interface {
// Watcher must be provided by every backend.
Watcher
// Initialize the discovery with URIs, a heartbeat, a ttl and optional settings.
Initialize(string, time.Duration, time.Duration, map[string]string) error
// Register to the discovery.
Register(string) error
}

View File

@@ -0,0 +1,131 @@
package discovery
import (
"testing"
"github.com/go-check/check"
)
// Hook up gocheck into the "go test" runner.
func Test(t *testing.T) { check.TestingT(t) }
type DiscoverySuite struct{}
var _ = check.Suite(&DiscoverySuite{})
func (s *DiscoverySuite) TestNewEntry(c *check.C) {
entry, err := NewEntry("127.0.0.1:2375")
c.Assert(err, check.IsNil)
c.Assert(entry.Equals(&Entry{Host: "127.0.0.1", Port: "2375"}), check.Equals, true)
c.Assert(entry.String(), check.Equals, "127.0.0.1:2375")
_, err = NewEntry("127.0.0.1")
c.Assert(err, check.NotNil)
}
func (s *DiscoverySuite) TestParse(c *check.C) {
scheme, uri := parse("127.0.0.1:2375")
c.Assert(scheme, check.Equals, "nodes")
c.Assert(uri, check.Equals, "127.0.0.1:2375")
scheme, uri = parse("localhost:2375")
c.Assert(scheme, check.Equals, "nodes")
c.Assert(uri, check.Equals, "localhost:2375")
scheme, uri = parse("scheme://127.0.0.1:2375")
c.Assert(scheme, check.Equals, "scheme")
c.Assert(uri, check.Equals, "127.0.0.1:2375")
scheme, uri = parse("scheme://localhost:2375")
c.Assert(scheme, check.Equals, "scheme")
c.Assert(uri, check.Equals, "localhost:2375")
scheme, uri = parse("")
c.Assert(scheme, check.Equals, "nodes")
c.Assert(uri, check.Equals, "")
}
func (s *DiscoverySuite) TestCreateEntries(c *check.C) {
entries, err := CreateEntries(nil)
c.Assert(entries, check.DeepEquals, Entries{})
c.Assert(err, check.IsNil)
entries, err = CreateEntries([]string{"127.0.0.1:2375", "127.0.0.2:2375", ""})
c.Assert(err, check.IsNil)
expected := Entries{
&Entry{Host: "127.0.0.1", Port: "2375"},
&Entry{Host: "127.0.0.2", Port: "2375"},
}
c.Assert(entries.Equals(expected), check.Equals, true)
_, err = CreateEntries([]string{"127.0.0.1", "127.0.0.2"})
c.Assert(err, check.NotNil)
}
func (s *DiscoverySuite) TestContainsEntry(c *check.C) {
entries, err := CreateEntries([]string{"127.0.0.1:2375", "127.0.0.2:2375", ""})
c.Assert(err, check.IsNil)
c.Assert(entries.Contains(&Entry{Host: "127.0.0.1", Port: "2375"}), check.Equals, true)
c.Assert(entries.Contains(&Entry{Host: "127.0.0.3", Port: "2375"}), check.Equals, false)
}
func (s *DiscoverySuite) TestEntriesEquality(c *check.C) {
entries := Entries{
&Entry{Host: "127.0.0.1", Port: "2375"},
&Entry{Host: "127.0.0.2", Port: "2375"},
}
// Same
c.Assert(entries.Equals(Entries{
&Entry{Host: "127.0.0.1", Port: "2375"},
&Entry{Host: "127.0.0.2", Port: "2375"},
}), check.
Equals, true)
// Different size
c.Assert(entries.Equals(Entries{
&Entry{Host: "127.0.0.1", Port: "2375"},
&Entry{Host: "127.0.0.2", Port: "2375"},
&Entry{Host: "127.0.0.3", Port: "2375"},
}), check.
Equals, false)
// Different content
c.Assert(entries.Equals(Entries{
&Entry{Host: "127.0.0.1", Port: "2375"},
&Entry{Host: "127.0.0.42", Port: "2375"},
}), check.
Equals, false)
}
func (s *DiscoverySuite) TestEntriesDiff(c *check.C) {
entry1 := &Entry{Host: "1.1.1.1", Port: "1111"}
entry2 := &Entry{Host: "2.2.2.2", Port: "2222"}
entry3 := &Entry{Host: "3.3.3.3", Port: "3333"}
entries := Entries{entry1, entry2}
// No diff
added, removed := entries.Diff(Entries{entry2, entry1})
c.Assert(added, check.HasLen, 0)
c.Assert(removed, check.HasLen, 0)
// Add
added, removed = entries.Diff(Entries{entry2, entry3, entry1})
c.Assert(added, check.HasLen, 1)
c.Assert(added.Contains(entry3), check.Equals, true)
c.Assert(removed, check.HasLen, 0)
// Remove
added, removed = entries.Diff(Entries{entry2})
c.Assert(added, check.HasLen, 0)
c.Assert(removed, check.HasLen, 1)
c.Assert(removed.Contains(entry1), check.Equals, true)
// Add and remove
added, removed = entries.Diff(Entries{entry1, entry3})
c.Assert(added, check.HasLen, 1)
c.Assert(added.Contains(entry3), check.Equals, true)
c.Assert(removed, check.HasLen, 1)
c.Assert(removed.Contains(entry2), check.Equals, true)
}

View File

@@ -0,0 +1,97 @@
package discovery
import (
"fmt"
"net"
)
// NewEntry creates a new entry.
func NewEntry(url string) (*Entry, error) {
host, port, err := net.SplitHostPort(url)
if err != nil {
return nil, err
}
return &Entry{host, port}, nil
}
// An Entry represents a host.
type Entry struct {
Host string
Port string
}
// Equals returns true if cmp contains the same data.
func (e *Entry) Equals(cmp *Entry) bool {
return e.Host == cmp.Host && e.Port == cmp.Port
}
// String returns the string form of an entry.
func (e *Entry) String() string {
return fmt.Sprintf("%s:%s", e.Host, e.Port)
}
// Entries is a list of *Entry with some helpers.
type Entries []*Entry
// Equals returns true if cmp contains the same data.
func (e Entries) Equals(cmp Entries) bool {
// Check if the file has really changed.
if len(e) != len(cmp) {
return false
}
for i := range e {
if !e[i].Equals(cmp[i]) {
return false
}
}
return true
}
// Contains returns true if the Entries contain a given Entry.
func (e Entries) Contains(entry *Entry) bool {
for _, curr := range e {
if curr.Equals(entry) {
return true
}
}
return false
}
// Diff compares two entries and returns the added and removed entries.
func (e Entries) Diff(cmp Entries) (Entries, Entries) {
added := Entries{}
for _, entry := range cmp {
if !e.Contains(entry) {
added = append(added, entry)
}
}
removed := Entries{}
for _, entry := range e {
if !cmp.Contains(entry) {
removed = append(removed, entry)
}
}
return added, removed
}
// CreateEntries returns an array of entries based on the given addresses.
func CreateEntries(addrs []string) (Entries, error) {
entries := Entries{}
if addrs == nil {
return entries, nil
}
for _, addr := range addrs {
if len(addr) == 0 {
continue
}
entry, err := NewEntry(addr)
if err != nil {
return nil, err
}
entries = append(entries, entry)
}
return entries, nil
}

View File

@@ -0,0 +1,109 @@
package file
import (
"fmt"
"io/ioutil"
"strings"
"time"
"github.com/hyperhq/hypercli/pkg/discovery"
)
// Discovery is exported
type Discovery struct {
heartbeat time.Duration
path string
}
func init() {
Init()
}
// Init is exported
func Init() {
discovery.Register("file", &Discovery{})
}
// Initialize is exported
func (s *Discovery) Initialize(path string, heartbeat time.Duration, ttl time.Duration, _ map[string]string) error {
s.path = path
s.heartbeat = heartbeat
return nil
}
func parseFileContent(content []byte) []string {
var result []string
for _, line := range strings.Split(strings.TrimSpace(string(content)), "\n") {
line = strings.TrimSpace(line)
// Ignoring line starts with #
if strings.HasPrefix(line, "#") {
continue
}
// Inlined # comment also ignored.
if strings.Contains(line, "#") {
line = line[0:strings.Index(line, "#")]
// Trim additional spaces caused by above stripping.
line = strings.TrimSpace(line)
}
for _, ip := range discovery.Generate(line) {
result = append(result, ip)
}
}
return result
}
func (s *Discovery) fetch() (discovery.Entries, error) {
fileContent, err := ioutil.ReadFile(s.path)
if err != nil {
return nil, fmt.Errorf("failed to read '%s': %v", s.path, err)
}
return discovery.CreateEntries(parseFileContent(fileContent))
}
// Watch is exported
func (s *Discovery) Watch(stopCh <-chan struct{}) (<-chan discovery.Entries, <-chan error) {
ch := make(chan discovery.Entries)
errCh := make(chan error)
ticker := time.NewTicker(s.heartbeat)
go func() {
defer close(errCh)
defer close(ch)
// Send the initial entries if available.
currentEntries, err := s.fetch()
if err != nil {
errCh <- err
} else {
ch <- currentEntries
}
// Periodically send updates.
for {
select {
case <-ticker.C:
newEntries, err := s.fetch()
if err != nil {
errCh <- err
continue
}
// Check if the file has really changed.
if !newEntries.Equals(currentEntries) {
ch <- newEntries
}
currentEntries = newEntries
case <-stopCh:
ticker.Stop()
return
}
}
}()
return ch, errCh
}
// Register is exported
func (s *Discovery) Register(addr string) error {
return discovery.ErrNotImplemented
}

View File

@@ -0,0 +1,114 @@
package file
import (
"io/ioutil"
"os"
"testing"
"github.com/hyperhq/hypercli/pkg/discovery"
"github.com/go-check/check"
)
// Hook up gocheck into the "go test" runner.
func Test(t *testing.T) { check.TestingT(t) }
type DiscoverySuite struct{}
var _ = check.Suite(&DiscoverySuite{})
func (s *DiscoverySuite) TestInitialize(c *check.C) {
d := &Discovery{}
d.Initialize("/path/to/file", 1000, 0, nil)
c.Assert(d.path, check.Equals, "/path/to/file")
}
func (s *DiscoverySuite) TestNew(c *check.C) {
d, err := discovery.New("file:///path/to/file", 0, 0, nil)
c.Assert(err, check.IsNil)
c.Assert(d.(*Discovery).path, check.Equals, "/path/to/file")
}
func (s *DiscoverySuite) TestContent(c *check.C) {
data := `
1.1.1.[1:2]:1111
2.2.2.[2:4]:2222
`
ips := parseFileContent([]byte(data))
c.Assert(ips, check.HasLen, 5)
c.Assert(ips[0], check.Equals, "1.1.1.1:1111")
c.Assert(ips[1], check.Equals, "1.1.1.2:1111")
c.Assert(ips[2], check.Equals, "2.2.2.2:2222")
c.Assert(ips[3], check.Equals, "2.2.2.3:2222")
c.Assert(ips[4], check.Equals, "2.2.2.4:2222")
}
func (s *DiscoverySuite) TestRegister(c *check.C) {
discovery := &Discovery{path: "/path/to/file"}
c.Assert(discovery.Register("0.0.0.0"), check.NotNil)
}
func (s *DiscoverySuite) TestParsingContentsWithComments(c *check.C) {
data := `
### test ###
1.1.1.1:1111 # inline comment
# 2.2.2.2:2222
### empty line with comment
3.3.3.3:3333
### test ###
`
ips := parseFileContent([]byte(data))
c.Assert(ips, check.HasLen, 2)
c.Assert("1.1.1.1:1111", check.Equals, ips[0])
c.Assert("3.3.3.3:3333", check.Equals, ips[1])
}
func (s *DiscoverySuite) TestWatch(c *check.C) {
data := `
1.1.1.1:1111
2.2.2.2:2222
`
expected := discovery.Entries{
&discovery.Entry{Host: "1.1.1.1", Port: "1111"},
&discovery.Entry{Host: "2.2.2.2", Port: "2222"},
}
// Create a temporary file and remove it.
tmp, err := ioutil.TempFile(os.TempDir(), "discovery-file-test")
c.Assert(err, check.IsNil)
c.Assert(tmp.Close(), check.IsNil)
c.Assert(os.Remove(tmp.Name()), check.IsNil)
// Set up file discovery.
d := &Discovery{}
d.Initialize(tmp.Name(), 1000, 0, nil)
stopCh := make(chan struct{})
ch, errCh := d.Watch(stopCh)
// Make sure it fires errors since the file doesn't exist.
c.Assert(<-errCh, check.NotNil)
// We have to drain the error channel otherwise Watch will get stuck.
go func() {
for range errCh {
}
}()
// Write the file and make sure we get the expected value back.
c.Assert(ioutil.WriteFile(tmp.Name(), []byte(data), 0600), check.IsNil)
c.Assert(<-ch, check.DeepEquals, expected)
// Add a new entry and look it up.
expected = append(expected, &discovery.Entry{Host: "3.3.3.3", Port: "3333"})
f, err := os.OpenFile(tmp.Name(), os.O_APPEND|os.O_WRONLY, 0600)
c.Assert(err, check.IsNil)
c.Assert(f, check.NotNil)
_, err = f.WriteString("\n3.3.3.3:3333\n")
c.Assert(err, check.IsNil)
f.Close()
c.Assert(<-ch, check.DeepEquals, expected)
// Stop and make sure it closes all channels.
close(stopCh)
c.Assert(<-ch, check.IsNil)
c.Assert(<-errCh, check.IsNil)
}

View File

@@ -0,0 +1,35 @@
package discovery
import (
"fmt"
"regexp"
"strconv"
)
// Generate takes care of IP generation
func Generate(pattern string) []string {
re, _ := regexp.Compile(`\[(.+):(.+)\]`)
submatch := re.FindStringSubmatch(pattern)
if submatch == nil {
return []string{pattern}
}
from, err := strconv.Atoi(submatch[1])
if err != nil {
return []string{pattern}
}
to, err := strconv.Atoi(submatch[2])
if err != nil {
return []string{pattern}
}
template := re.ReplaceAllString(pattern, "%d")
var result []string
for val := from; val <= to; val++ {
entry := fmt.Sprintf(template, val)
result = append(result, entry)
}
return result
}

View File

@@ -0,0 +1,53 @@
package discovery
import (
"github.com/go-check/check"
)
func (s *DiscoverySuite) TestGeneratorNotGenerate(c *check.C) {
ips := Generate("127.0.0.1")
c.Assert(len(ips), check.Equals, 1)
c.Assert(ips[0], check.Equals, "127.0.0.1")
}
func (s *DiscoverySuite) TestGeneratorWithPortNotGenerate(c *check.C) {
ips := Generate("127.0.0.1:8080")
c.Assert(len(ips), check.Equals, 1)
c.Assert(ips[0], check.Equals, "127.0.0.1:8080")
}
func (s *DiscoverySuite) TestGeneratorMatchFailedNotGenerate(c *check.C) {
ips := Generate("127.0.0.[1]")
c.Assert(len(ips), check.Equals, 1)
c.Assert(ips[0], check.Equals, "127.0.0.[1]")
}
func (s *DiscoverySuite) TestGeneratorWithPort(c *check.C) {
ips := Generate("127.0.0.[1:11]:2375")
c.Assert(len(ips), check.Equals, 11)
c.Assert(ips[0], check.Equals, "127.0.0.1:2375")
c.Assert(ips[1], check.Equals, "127.0.0.2:2375")
c.Assert(ips[2], check.Equals, "127.0.0.3:2375")
c.Assert(ips[3], check.Equals, "127.0.0.4:2375")
c.Assert(ips[4], check.Equals, "127.0.0.5:2375")
c.Assert(ips[5], check.Equals, "127.0.0.6:2375")
c.Assert(ips[6], check.Equals, "127.0.0.7:2375")
c.Assert(ips[7], check.Equals, "127.0.0.8:2375")
c.Assert(ips[8], check.Equals, "127.0.0.9:2375")
c.Assert(ips[9], check.Equals, "127.0.0.10:2375")
c.Assert(ips[10], check.Equals, "127.0.0.11:2375")
}
func (s *DiscoverySuite) TestGenerateWithMalformedInputAtRangeStart(c *check.C) {
malformedInput := "127.0.0.[x:11]:2375"
ips := Generate(malformedInput)
c.Assert(len(ips), check.Equals, 1)
c.Assert(ips[0], check.Equals, malformedInput)
}
func (s *DiscoverySuite) TestGenerateWithMalformedInputAtRangeEnd(c *check.C) {
malformedInput := "127.0.0.[1:x]:2375"
ips := Generate(malformedInput)
c.Assert(len(ips), check.Equals, 1)
c.Assert(ips[0], check.Equals, malformedInput)
}

View File

@@ -0,0 +1,192 @@
package kv
import (
"fmt"
"path"
"strings"
"time"
log "github.com/Sirupsen/logrus"
"github.com/hyperhq/hypercli/pkg/discovery"
"github.com/docker/go-connections/tlsconfig"
"github.com/docker/libkv"
"github.com/docker/libkv/store"
"github.com/docker/libkv/store/consul"
"github.com/docker/libkv/store/etcd"
"github.com/docker/libkv/store/zookeeper"
)
const (
defaultDiscoveryPath = "docker/nodes"
)
// Discovery is exported
type Discovery struct {
backend store.Backend
store store.Store
heartbeat time.Duration
ttl time.Duration
prefix string
path string
}
func init() {
Init()
}
// Init is exported
func Init() {
// Register to libkv
zookeeper.Register()
consul.Register()
etcd.Register()
// Register to internal discovery service
discovery.Register("zk", &Discovery{backend: store.ZK})
discovery.Register("consul", &Discovery{backend: store.CONSUL})
discovery.Register("etcd", &Discovery{backend: store.ETCD})
}
// Initialize is exported
func (s *Discovery) Initialize(uris string, heartbeat time.Duration, ttl time.Duration, clusterOpts map[string]string) error {
var (
parts = strings.SplitN(uris, "/", 2)
addrs = strings.Split(parts[0], ",")
err error
)
// A custom prefix to the path can be optionally used.
if len(parts) == 2 {
s.prefix = parts[1]
}
s.heartbeat = heartbeat
s.ttl = ttl
// Use a custom path if specified in discovery options
dpath := defaultDiscoveryPath
if clusterOpts["kv.path"] != "" {
dpath = clusterOpts["kv.path"]
}
s.path = path.Join(s.prefix, dpath)
var config *store.Config
if clusterOpts["kv.cacertfile"] != "" && clusterOpts["kv.certfile"] != "" && clusterOpts["kv.keyfile"] != "" {
log.Info("Initializing discovery with TLS")
tlsConfig, err := tlsconfig.Client(tlsconfig.Options{
CAFile: clusterOpts["kv.cacertfile"],
CertFile: clusterOpts["kv.certfile"],
KeyFile: clusterOpts["kv.keyfile"],
})
if err != nil {
return err
}
config = &store.Config{
// Set ClientTLS to trigger https (bug in libkv/etcd)
ClientTLS: &store.ClientTLSConfig{
CACertFile: clusterOpts["kv.cacertfile"],
CertFile: clusterOpts["kv.certfile"],
KeyFile: clusterOpts["kv.keyfile"],
},
// The actual TLS config that will be used
TLS: tlsConfig,
}
} else {
log.Info("Initializing discovery without TLS")
}
// Creates a new store, will ignore options given
// if not supported by the chosen store
s.store, err = libkv.NewStore(s.backend, addrs, config)
return err
}
// Watch the store until either there's a store error or we receive a stop request.
// Returns false if we shouldn't attempt watching the store anymore (stop request received).
func (s *Discovery) watchOnce(stopCh <-chan struct{}, watchCh <-chan []*store.KVPair, discoveryCh chan discovery.Entries, errCh chan error) bool {
for {
select {
case pairs := <-watchCh:
if pairs == nil {
return true
}
log.WithField("discovery", s.backend).Debugf("Watch triggered with %d nodes", len(pairs))
// Convert `KVPair` into `discovery.Entry`.
addrs := make([]string, len(pairs))
for _, pair := range pairs {
addrs = append(addrs, string(pair.Value))
}
entries, err := discovery.CreateEntries(addrs)
if err != nil {
errCh <- err
} else {
discoveryCh <- entries
}
case <-stopCh:
// We were requested to stop watching.
return false
}
}
}
// Watch is exported
func (s *Discovery) Watch(stopCh <-chan struct{}) (<-chan discovery.Entries, <-chan error) {
ch := make(chan discovery.Entries)
errCh := make(chan error)
go func() {
defer close(ch)
defer close(errCh)
// Forever: Create a store watch, watch until we get an error and then try again.
// Will only stop if we receive a stopCh request.
for {
// Create the path to watch if it does not exist yet
exists, err := s.store.Exists(s.path)
if err != nil {
errCh <- err
}
if !exists {
if err := s.store.Put(s.path, []byte(""), &store.WriteOptions{IsDir: true}); err != nil {
errCh <- err
}
}
// Set up a watch.
watchCh, err := s.store.WatchTree(s.path, stopCh)
if err != nil {
errCh <- err
} else {
if !s.watchOnce(stopCh, watchCh, ch, errCh) {
return
}
}
// If we get here it means the store watch channel was closed. This
// is unexpected so let's retry later.
errCh <- fmt.Errorf("Unexpected watch error")
time.Sleep(s.heartbeat)
}
}()
return ch, errCh
}
// Register is exported
func (s *Discovery) Register(addr string) error {
opts := &store.WriteOptions{TTL: s.ttl}
return s.store.Put(path.Join(s.path, addr), []byte(addr), opts)
}
// Store returns the underlying store used by KV discovery.
func (s *Discovery) Store() store.Store {
return s.store
}
// Prefix returns the store prefix
func (s *Discovery) Prefix() string {
return s.prefix
}

View File

@@ -0,0 +1,324 @@
package kv
import (
"errors"
"io/ioutil"
"os"
"path"
"testing"
"time"
"github.com/hyperhq/hypercli/pkg/discovery"
"github.com/docker/libkv"
"github.com/docker/libkv/store"
"github.com/go-check/check"
)
// Hook up gocheck into the "go test" runner.
func Test(t *testing.T) { check.TestingT(t) }
type DiscoverySuite struct{}
var _ = check.Suite(&DiscoverySuite{})
func (ds *DiscoverySuite) TestInitialize(c *check.C) {
storeMock := &FakeStore{
Endpoints: []string{"127.0.0.1"},
}
d := &Discovery{backend: store.CONSUL}
d.Initialize("127.0.0.1", 0, 0, nil)
d.store = storeMock
s := d.store.(*FakeStore)
c.Assert(s.Endpoints, check.HasLen, 1)
c.Assert(s.Endpoints[0], check.Equals, "127.0.0.1")
c.Assert(d.path, check.Equals, defaultDiscoveryPath)
storeMock = &FakeStore{
Endpoints: []string{"127.0.0.1:1234"},
}
d = &Discovery{backend: store.CONSUL}
d.Initialize("127.0.0.1:1234/path", 0, 0, nil)
d.store = storeMock
s = d.store.(*FakeStore)
c.Assert(s.Endpoints, check.HasLen, 1)
c.Assert(s.Endpoints[0], check.Equals, "127.0.0.1:1234")
c.Assert(d.path, check.Equals, "path/"+defaultDiscoveryPath)
storeMock = &FakeStore{
Endpoints: []string{"127.0.0.1:1234", "127.0.0.2:1234", "127.0.0.3:1234"},
}
d = &Discovery{backend: store.CONSUL}
d.Initialize("127.0.0.1:1234,127.0.0.2:1234,127.0.0.3:1234/path", 0, 0, nil)
d.store = storeMock
s = d.store.(*FakeStore)
c.Assert(s.Endpoints, check.HasLen, 3)
c.Assert(s.Endpoints[0], check.Equals, "127.0.0.1:1234")
c.Assert(s.Endpoints[1], check.Equals, "127.0.0.2:1234")
c.Assert(s.Endpoints[2], check.Equals, "127.0.0.3:1234")
c.Assert(d.path, check.Equals, "path/"+defaultDiscoveryPath)
}
// Extremely limited mock store so we can test initialization
type Mock struct {
// Endpoints passed to InitializeMock
Endpoints []string
// Options passed to InitializeMock
Options *store.Config
}
func NewMock(endpoints []string, options *store.Config) (store.Store, error) {
s := &Mock{}
s.Endpoints = endpoints
s.Options = options
return s, nil
}
func (s *Mock) Put(key string, value []byte, opts *store.WriteOptions) error {
return errors.New("Put not supported")
}
func (s *Mock) Get(key string) (*store.KVPair, error) {
return nil, errors.New("Get not supported")
}
func (s *Mock) Delete(key string) error {
return errors.New("Delete not supported")
}
// Exists mock
func (s *Mock) Exists(key string) (bool, error) {
return false, errors.New("Exists not supported")
}
// Watch mock
func (s *Mock) Watch(key string, stopCh <-chan struct{}) (<-chan *store.KVPair, error) {
return nil, errors.New("Watch not supported")
}
// WatchTree mock
func (s *Mock) WatchTree(prefix string, stopCh <-chan struct{}) (<-chan []*store.KVPair, error) {
return nil, errors.New("WatchTree not supported")
}
// NewLock mock
func (s *Mock) NewLock(key string, options *store.LockOptions) (store.Locker, error) {
return nil, errors.New("NewLock not supported")
}
// List mock
func (s *Mock) List(prefix string) ([]*store.KVPair, error) {
return nil, errors.New("List not supported")
}
// DeleteTree mock
func (s *Mock) DeleteTree(prefix string) error {
return errors.New("DeleteTree not supported")
}
// AtomicPut mock
func (s *Mock) AtomicPut(key string, value []byte, previous *store.KVPair, opts *store.WriteOptions) (bool, *store.KVPair, error) {
return false, nil, errors.New("AtomicPut not supported")
}
// AtomicDelete mock
func (s *Mock) AtomicDelete(key string, previous *store.KVPair) (bool, error) {
return false, errors.New("AtomicDelete not supported")
}
// Close mock
func (s *Mock) Close() {
return
}
func (ds *DiscoverySuite) TestInitializeWithCerts(c *check.C) {
cert := `-----BEGIN CERTIFICATE-----
MIIDCDCCAfKgAwIBAgIICifG7YeiQOEwCwYJKoZIhvcNAQELMBIxEDAOBgNVBAMT
B1Rlc3QgQ0EwHhcNMTUxMDAxMjMwMDAwWhcNMjAwOTI5MjMwMDAwWjASMRAwDgYD
VQQDEwdUZXN0IENBMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA1wRC
O+flnLTK5ImjTurNRHwSejuqGbc4CAvpB0hS+z0QlSs4+zE9h80aC4hz+6caRpds
+J908Q+RvAittMHbpc7VjbZP72G6fiXk7yPPl6C10HhRSoSi3nY+B7F2E8cuz14q
V2e+ejhWhSrBb/keyXpcyjoW1BOAAJ2TIclRRkICSCZrpXUyXxAvzXfpFXo1RhSb
UywN11pfiCQzDUN7sPww9UzFHuAHZHoyfTr27XnJYVUerVYrCPq8vqfn//01qz55
Xs0hvzGdlTFXhuabFtQnKFH5SNwo/fcznhB7rePOwHojxOpXTBepUCIJLbtNnWFT
V44t9gh5IqIWtoBReQIDAQABo2YwZDAOBgNVHQ8BAf8EBAMCAAYwEgYDVR0TAQH/
BAgwBgEB/wIBAjAdBgNVHQ4EFgQUZKUI8IIjIww7X/6hvwggQK4bD24wHwYDVR0j
BBgwFoAUZKUI8IIjIww7X/6hvwggQK4bD24wCwYJKoZIhvcNAQELA4IBAQDES2cz
7sCQfDCxCIWH7X8kpi/JWExzUyQEJ0rBzN1m3/x8ySRxtXyGekimBqQwQdFqlwMI
xzAQKkh3ue8tNSzRbwqMSyH14N1KrSxYS9e9szJHfUasoTpQGPmDmGIoRJuq1h6M
ej5x1SCJ7GWCR6xEXKUIE9OftXm9TdFzWa7Ja3OHz/mXteii8VXDuZ5ACq6EE5bY
8sP4gcICfJ5fTrpTlk9FIqEWWQrCGa5wk95PGEj+GJpNogjXQ97wVoo/Y3p1brEn
t5zjN9PAq4H1fuCMdNNA+p1DHNwd+ELTxcMAnb2ajwHvV6lKPXutrTFc4umJToBX
FpTxDmJHEV4bzUzh
-----END CERTIFICATE-----
`
key := `-----BEGIN RSA PRIVATE KEY-----
MIIEpQIBAAKCAQEA1wRCO+flnLTK5ImjTurNRHwSejuqGbc4CAvpB0hS+z0QlSs4
+zE9h80aC4hz+6caRpds+J908Q+RvAittMHbpc7VjbZP72G6fiXk7yPPl6C10HhR
SoSi3nY+B7F2E8cuz14qV2e+ejhWhSrBb/keyXpcyjoW1BOAAJ2TIclRRkICSCZr
pXUyXxAvzXfpFXo1RhSbUywN11pfiCQzDUN7sPww9UzFHuAHZHoyfTr27XnJYVUe
rVYrCPq8vqfn//01qz55Xs0hvzGdlTFXhuabFtQnKFH5SNwo/fcznhB7rePOwHoj
xOpXTBepUCIJLbtNnWFTV44t9gh5IqIWtoBReQIDAQABAoIBAHSWipORGp/uKFXj
i/mut776x8ofsAxhnLBARQr93ID+i49W8H7EJGkOfaDjTICYC1dbpGrri61qk8sx
qX7p3v/5NzKwOIfEpirgwVIqSNYe/ncbxnhxkx6tXtUtFKmEx40JskvSpSYAhmmO
1XSx0E/PWaEN/nLgX/f1eWJIlxlQkk3QeqL+FGbCXI48DEtlJ9+MzMu4pAwZTpj5
5qtXo5JJ0jRGfJVPAOznRsYqv864AhMdMIWguzk6EGnbaCWwPcfcn+h9a5LMdony
MDHfBS7bb5tkF3+AfnVY3IBMVx7YlsD9eAyajlgiKu4zLbwTRHjXgShy+4Oussz0
ugNGnkECgYEA/hi+McrZC8C4gg6XqK8+9joD8tnyDZDz88BQB7CZqABUSwvjDqlP
L8hcwo/lzvjBNYGkqaFPUICGWKjeCtd8pPS2DCVXxDQX4aHF1vUur0uYNncJiV3N
XQz4Iemsa6wnKf6M67b5vMXICw7dw0HZCdIHD1hnhdtDz0uVpeevLZ8CgYEA2KCT
Y43lorjrbCgMqtlefkr3GJA9dey+hTzCiWEOOqn9RqGoEGUday0sKhiLofOgmN2B
LEukpKIey8s+Q/cb6lReajDVPDsMweX8i7hz3Wa4Ugp4Xa5BpHqu8qIAE2JUZ7bU
t88aQAYE58pUF+/Lq1QzAQdrjjzQBx6SrBxieecCgYEAvukoPZEC8mmiN1VvbTX+
QFHmlZha3QaDxChB+QUe7bMRojEUL/fVnzkTOLuVFqSfxevaI/km9n0ac5KtAchV
xjp2bTnBb5EUQFqjopYktWA+xO07JRJtMfSEmjZPbbay1kKC7rdTfBm961EIHaRj
xZUf6M+rOE8964oGrdgdLlECgYEA046GQmx6fh7/82FtdZDRQp9tj3SWQUtSiQZc
qhO59Lq8mjUXz+MgBuJXxkiwXRpzlbaFB0Bca1fUoYw8o915SrDYf/Zu2OKGQ/qa
V81sgiVmDuEgycR7YOlbX6OsVUHrUlpwhY3hgfMe6UtkMvhBvHF/WhroBEIJm1pV
PXZ/CbMCgYEApNWVktFBjOaYfY6SNn4iSts1jgsQbbpglg3kT7PLKjCAhI6lNsbk
dyT7ut01PL6RaW4SeQWtrJIVQaM6vF3pprMKqlc5XihOGAmVqH7rQx9rtQB5TicL
BFrwkQE4HQtQBV60hYQUzzlSk44VFDz+jxIEtacRHaomDRh2FtOTz+I=
-----END RSA PRIVATE KEY-----
`
certFile, err := ioutil.TempFile("", "cert")
c.Assert(err, check.IsNil)
defer os.Remove(certFile.Name())
certFile.Write([]byte(cert))
certFile.Close()
keyFile, err := ioutil.TempFile("", "key")
c.Assert(err, check.IsNil)
defer os.Remove(keyFile.Name())
keyFile.Write([]byte(key))
keyFile.Close()
libkv.AddStore("mock", NewMock)
d := &Discovery{backend: "mock"}
err = d.Initialize("127.0.0.3:1234", 0, 0, map[string]string{
"kv.cacertfile": certFile.Name(),
"kv.certfile": certFile.Name(),
"kv.keyfile": keyFile.Name(),
})
c.Assert(err, check.IsNil)
s := d.store.(*Mock)
c.Assert(s.Options.TLS, check.NotNil)
c.Assert(s.Options.TLS.RootCAs, check.NotNil)
c.Assert(s.Options.TLS.Certificates, check.HasLen, 1)
}
func (ds *DiscoverySuite) TestWatch(c *check.C) {
mockCh := make(chan []*store.KVPair)
storeMock := &FakeStore{
Endpoints: []string{"127.0.0.1:1234"},
mockKVChan: mockCh,
}
d := &Discovery{backend: store.CONSUL}
d.Initialize("127.0.0.1:1234/path", 0, 0, nil)
d.store = storeMock
expected := discovery.Entries{
&discovery.Entry{Host: "1.1.1.1", Port: "1111"},
&discovery.Entry{Host: "2.2.2.2", Port: "2222"},
}
kvs := []*store.KVPair{
{Key: path.Join("path", defaultDiscoveryPath, "1.1.1.1"), Value: []byte("1.1.1.1:1111")},
{Key: path.Join("path", defaultDiscoveryPath, "2.2.2.2"), Value: []byte("2.2.2.2:2222")},
}
stopCh := make(chan struct{})
ch, errCh := d.Watch(stopCh)
// It should fire an error since the first WatchTree call failed.
c.Assert(<-errCh, check.ErrorMatches, "test error")
// We have to drain the error channel otherwise Watch will get stuck.
go func() {
for range errCh {
}
}()
// Push the entries into the store channel and make sure discovery emits.
mockCh <- kvs
c.Assert(<-ch, check.DeepEquals, expected)
// Add a new entry.
expected = append(expected, &discovery.Entry{Host: "3.3.3.3", Port: "3333"})
kvs = append(kvs, &store.KVPair{Key: path.Join("path", defaultDiscoveryPath, "3.3.3.3"), Value: []byte("3.3.3.3:3333")})
mockCh <- kvs
c.Assert(<-ch, check.DeepEquals, expected)
close(mockCh)
// Give it enough time to call WatchTree.
time.Sleep(3)
// Stop and make sure it closes all channels.
close(stopCh)
c.Assert(<-ch, check.IsNil)
c.Assert(<-errCh, check.IsNil)
}
// FakeStore implements store.Store methods. It mocks all store
// function in a simple, naive way.
type FakeStore struct {
Endpoints []string
Options *store.Config
mockKVChan <-chan []*store.KVPair
watchTreeCallCount int
}
func (s *FakeStore) Put(key string, value []byte, options *store.WriteOptions) error {
return nil
}
func (s *FakeStore) Get(key string) (*store.KVPair, error) {
return nil, nil
}
func (s *FakeStore) Delete(key string) error {
return nil
}
func (s *FakeStore) Exists(key string) (bool, error) {
return true, nil
}
func (s *FakeStore) Watch(key string, stopCh <-chan struct{}) (<-chan *store.KVPair, error) {
return nil, nil
}
// WatchTree will fail the first time, and return the mockKVchan afterwards.
// This is the behavior we need for testing.. If we need 'moar', should update this.
func (s *FakeStore) WatchTree(directory string, stopCh <-chan struct{}) (<-chan []*store.KVPair, error) {
if s.watchTreeCallCount == 0 {
s.watchTreeCallCount = 1
return nil, errors.New("test error")
}
// First calls error
return s.mockKVChan, nil
}
func (s *FakeStore) NewLock(key string, options *store.LockOptions) (store.Locker, error) {
return nil, nil
}
func (s *FakeStore) List(directory string) ([]*store.KVPair, error) {
return []*store.KVPair{}, nil
}
func (s *FakeStore) DeleteTree(directory string) error {
return nil
}
func (s *FakeStore) AtomicPut(key string, value []byte, previous *store.KVPair, options *store.WriteOptions) (bool, *store.KVPair, error) {
return true, nil, nil
}
func (s *FakeStore) AtomicDelete(key string, previous *store.KVPair) (bool, error) {
return true, nil
}
func (s *FakeStore) Close() {
}

View File

@@ -0,0 +1,83 @@
package memory
import (
"time"
"github.com/hyperhq/hypercli/pkg/discovery"
)
// Discovery implements a descovery backend that keeps
// data in memory.
type Discovery struct {
heartbeat time.Duration
values []string
}
func init() {
Init()
}
// Init registers the memory backend on demand.
func Init() {
discovery.Register("memory", &Discovery{})
}
// Initialize sets the heartbeat for the memory backend.
func (s *Discovery) Initialize(_ string, heartbeat time.Duration, _ time.Duration, _ map[string]string) error {
s.heartbeat = heartbeat
s.values = make([]string, 0)
return nil
}
// Watch sends periodic discovery updates to a channel.
func (s *Discovery) Watch(stopCh <-chan struct{}) (<-chan discovery.Entries, <-chan error) {
ch := make(chan discovery.Entries)
errCh := make(chan error)
ticker := time.NewTicker(s.heartbeat)
go func() {
defer close(errCh)
defer close(ch)
// Send the initial entries if available.
var currentEntries discovery.Entries
if len(s.values) > 0 {
var err error
currentEntries, err = discovery.CreateEntries(s.values)
if err != nil {
errCh <- err
} else {
ch <- currentEntries
}
}
// Periodically send updates.
for {
select {
case <-ticker.C:
newEntries, err := discovery.CreateEntries(s.values)
if err != nil {
errCh <- err
continue
}
// Check if the file has really changed.
if !newEntries.Equals(currentEntries) {
ch <- newEntries
}
currentEntries = newEntries
case <-stopCh:
ticker.Stop()
return
}
}
}()
return ch, errCh
}
// Register adds a new address to the discovery.
func (s *Discovery) Register(addr string) error {
s.values = append(s.values, addr)
return nil
}

View File

@@ -0,0 +1,48 @@
package memory
import (
"testing"
"github.com/hyperhq/hypercli/pkg/discovery"
"github.com/go-check/check"
)
// Hook up gocheck into the "go test" runner.
func Test(t *testing.T) { check.TestingT(t) }
type discoverySuite struct{}
var _ = check.Suite(&discoverySuite{})
func (s *discoverySuite) TestWatch(c *check.C) {
d := &Discovery{}
d.Initialize("foo", 1000, 0, nil)
stopCh := make(chan struct{})
ch, errCh := d.Watch(stopCh)
// We have to drain the error channel otherwise Watch will get stuck.
go func() {
for range errCh {
}
}()
expected := discovery.Entries{
&discovery.Entry{Host: "1.1.1.1", Port: "1111"},
}
c.Assert(d.Register("1.1.1.1:1111"), check.IsNil)
c.Assert(<-ch, check.DeepEquals, expected)
expected = discovery.Entries{
&discovery.Entry{Host: "1.1.1.1", Port: "1111"},
&discovery.Entry{Host: "2.2.2.2", Port: "2222"},
}
c.Assert(d.Register("2.2.2.2:2222"), check.IsNil)
c.Assert(<-ch, check.DeepEquals, expected)
// Stop and make sure it closes all channels.
close(stopCh)
c.Assert(<-ch, check.IsNil)
c.Assert(<-errCh, check.IsNil)
}

View File

@@ -0,0 +1,54 @@
package nodes
import (
"fmt"
"strings"
"time"
"github.com/hyperhq/hypercli/pkg/discovery"
)
// Discovery is exported
type Discovery struct {
entries discovery.Entries
}
func init() {
Init()
}
// Init is exported
func Init() {
discovery.Register("nodes", &Discovery{})
}
// Initialize is exported
func (s *Discovery) Initialize(uris string, _ time.Duration, _ time.Duration, _ map[string]string) error {
for _, input := range strings.Split(uris, ",") {
for _, ip := range discovery.Generate(input) {
entry, err := discovery.NewEntry(ip)
if err != nil {
return fmt.Errorf("%s, please check you are using the correct discovery (missing token:// ?)", err.Error())
}
s.entries = append(s.entries, entry)
}
}
return nil
}
// Watch is exported
func (s *Discovery) Watch(stopCh <-chan struct{}) (<-chan discovery.Entries, <-chan error) {
ch := make(chan discovery.Entries)
go func() {
defer close(ch)
ch <- s.entries
<-stopCh
}()
return ch, nil
}
// Register is exported
func (s *Discovery) Register(addr string) error {
return discovery.ErrNotImplemented
}

View File

@@ -0,0 +1,51 @@
package nodes
import (
"testing"
"github.com/hyperhq/hypercli/pkg/discovery"
"github.com/go-check/check"
)
// Hook up gocheck into the "go test" runner.
func Test(t *testing.T) { check.TestingT(t) }
type DiscoverySuite struct{}
var _ = check.Suite(&DiscoverySuite{})
func (s *DiscoverySuite) TestInitialize(c *check.C) {
d := &Discovery{}
d.Initialize("1.1.1.1:1111,2.2.2.2:2222", 0, 0, nil)
c.Assert(len(d.entries), check.Equals, 2)
c.Assert(d.entries[0].String(), check.Equals, "1.1.1.1:1111")
c.Assert(d.entries[1].String(), check.Equals, "2.2.2.2:2222")
}
func (s *DiscoverySuite) TestInitializeWithPattern(c *check.C) {
d := &Discovery{}
d.Initialize("1.1.1.[1:2]:1111,2.2.2.[2:4]:2222", 0, 0, nil)
c.Assert(len(d.entries), check.Equals, 5)
c.Assert(d.entries[0].String(), check.Equals, "1.1.1.1:1111")
c.Assert(d.entries[1].String(), check.Equals, "1.1.1.2:1111")
c.Assert(d.entries[2].String(), check.Equals, "2.2.2.2:2222")
c.Assert(d.entries[3].String(), check.Equals, "2.2.2.3:2222")
c.Assert(d.entries[4].String(), check.Equals, "2.2.2.4:2222")
}
func (s *DiscoverySuite) TestWatch(c *check.C) {
d := &Discovery{}
d.Initialize("1.1.1.1:1111,2.2.2.2:2222", 0, 0, nil)
expected := discovery.Entries{
&discovery.Entry{Host: "1.1.1.1", Port: "1111"},
&discovery.Entry{Host: "2.2.2.2", Port: "2222"},
}
ch, _ := d.Watch(nil)
c.Assert(expected.Equals(<-ch), check.Equals, true)
}
func (s *DiscoverySuite) TestRegister(c *check.C) {
d := &Discovery{}
c.Assert(d.Register("0.0.0.0"), check.NotNil)
}

View File

@@ -0,0 +1,40 @@
// Package filenotify provides a mechanism for watching file(s) for changes.
// Generally leans on fsnotify, but provides a poll-based notifier which fsnotify does not support.
// These are wrapped up in a common interface so that either can be used interchangeably in your code.
package filenotify
import "gopkg.in/fsnotify.v1"
// FileWatcher is an interface for implementing file notification watchers
type FileWatcher interface {
Events() <-chan fsnotify.Event
Errors() <-chan error
Add(name string) error
Remove(name string) error
Close() error
}
// New tries to use an fs-event watcher, and falls back to the poller if there is an error
func New() (FileWatcher, error) {
if watcher, err := NewEventWatcher(); err == nil {
return watcher, nil
}
return NewPollingWatcher(), nil
}
// NewPollingWatcher returns a poll-based file watcher
func NewPollingWatcher() FileWatcher {
return &filePoller{
events: make(chan fsnotify.Event),
errors: make(chan error),
}
}
// NewEventWatcher returns an fs-event based file watcher
func NewEventWatcher() (FileWatcher, error) {
watcher, err := fsnotify.NewWatcher()
if err != nil {
return nil, err
}
return &fsNotifyWatcher{watcher}, nil
}

View File

@@ -0,0 +1,18 @@
package filenotify
import "gopkg.in/fsnotify.v1"
// fsNotify wraps the fsnotify package to satisfy the FileNotifer interface
type fsNotifyWatcher struct {
*fsnotify.Watcher
}
// GetEvents returns the fsnotify event channel receiver
func (w *fsNotifyWatcher) Events() <-chan fsnotify.Event {
return w.Watcher.Events
}
// GetErrors returns the fsnotify error channel receiver
func (w *fsNotifyWatcher) Errors() <-chan error {
return w.Watcher.Errors
}

View File

@@ -0,0 +1,205 @@
package filenotify
import (
"errors"
"fmt"
"os"
"sync"
"time"
"github.com/Sirupsen/logrus"
"gopkg.in/fsnotify.v1"
)
var (
// errPollerClosed is returned when the poller is closed
errPollerClosed = errors.New("poller is closed")
// errNoSuchPoller is returned when trying to remove a watch that doesn't exist
errNoSuchWatch = errors.New("poller does not exist")
)
// watchWaitTime is the time to wait between file poll loops
const watchWaitTime = 200 * time.Millisecond
// filePoller is used to poll files for changes, especially in cases where fsnotify
// can't be run (e.g. when inotify handles are exhausted)
// filePoller satisfies the FileWatcher interface
type filePoller struct {
// watches is the list of files currently being polled, close the associated channel to stop the watch
watches map[string]chan struct{}
// events is the channel to listen to for watch events
events chan fsnotify.Event
// errors is the channel to listen to for watch errors
errors chan error
// mu locks the poller for modification
mu sync.Mutex
// closed is used to specify when the poller has already closed
closed bool
}
// Add adds a filename to the list of watches
// once added the file is polled for changes in a separate goroutine
func (w *filePoller) Add(name string) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.closed == true {
return errPollerClosed
}
f, err := os.Open(name)
if err != nil {
return err
}
fi, err := os.Stat(name)
if err != nil {
return err
}
if w.watches == nil {
w.watches = make(map[string]chan struct{})
}
if _, exists := w.watches[name]; exists {
return fmt.Errorf("watch exists")
}
chClose := make(chan struct{})
w.watches[name] = chClose
go w.watch(f, fi, chClose)
return nil
}
// Remove stops and removes watch with the specified name
func (w *filePoller) Remove(name string) error {
w.mu.Lock()
defer w.mu.Unlock()
return w.remove(name)
}
func (w *filePoller) remove(name string) error {
if w.closed == true {
return errPollerClosed
}
chClose, exists := w.watches[name]
if !exists {
return errNoSuchWatch
}
close(chClose)
delete(w.watches, name)
return nil
}
// Events returns the event channel
// This is used for notifications on events about watched files
func (w *filePoller) Events() <-chan fsnotify.Event {
return w.events
}
// Errors returns the errors channel
// This is used for notifications about errors on watched files
func (w *filePoller) Errors() <-chan error {
return w.errors
}
// Close closes the poller
// All watches are stopped, removed, and the poller cannot be added to
func (w *filePoller) Close() error {
w.mu.Lock()
defer w.mu.Unlock()
if w.closed {
return nil
}
w.closed = true
for name := range w.watches {
w.remove(name)
delete(w.watches, name)
}
close(w.events)
close(w.errors)
return nil
}
// sendEvent publishes the specified event to the events channel
func (w *filePoller) sendEvent(e fsnotify.Event, chClose <-chan struct{}) error {
select {
case w.events <- e:
case <-chClose:
return fmt.Errorf("closed")
}
return nil
}
// sendErr publishes the specified error to the errors channel
func (w *filePoller) sendErr(e error, chClose <-chan struct{}) error {
select {
case w.errors <- e:
case <-chClose:
return fmt.Errorf("closed")
}
return nil
}
// watch is responsible for polling the specified file for changes
// upon finding changes to a file or errors, sendEvent/sendErr is called
func (w *filePoller) watch(f *os.File, lastFi os.FileInfo, chClose chan struct{}) {
for {
time.Sleep(watchWaitTime)
select {
case <-chClose:
logrus.Debugf("watch for %s closed", f.Name())
return
default:
}
fi, err := os.Stat(f.Name())
if err != nil {
// if we got an error here and lastFi is not set, we can presume that nothing has changed
// This should be safe since before `watch()` is called, a stat is performed, there is any error `watch` is not called
if lastFi == nil {
continue
}
// If it doesn't exist at this point, it must have been removed
// no need to send the error here since this is a valid operation
if os.IsNotExist(err) {
if err := w.sendEvent(fsnotify.Event{Op: fsnotify.Remove, Name: f.Name()}, chClose); err != nil {
return
}
lastFi = nil
continue
}
// at this point, send the error
if err := w.sendErr(err, chClose); err != nil {
return
}
continue
}
if lastFi == nil {
if err := w.sendEvent(fsnotify.Event{Op: fsnotify.Create, Name: fi.Name()}, chClose); err != nil {
return
}
lastFi = fi
continue
}
if fi.Mode() != lastFi.Mode() {
if err := w.sendEvent(fsnotify.Event{Op: fsnotify.Chmod, Name: fi.Name()}, chClose); err != nil {
return
}
lastFi = fi
continue
}
if fi.ModTime() != lastFi.ModTime() || fi.Size() != lastFi.Size() {
if err := w.sendEvent(fsnotify.Event{Op: fsnotify.Write, Name: fi.Name()}, chClose); err != nil {
return
}
lastFi = fi
continue
}
}
}

View File

@@ -0,0 +1,133 @@
package filenotify
import (
"fmt"
"io/ioutil"
"os"
"testing"
"time"
"gopkg.in/fsnotify.v1"
)
func TestPollerAddRemove(t *testing.T) {
w := NewPollingWatcher()
if err := w.Add("no-such-file"); err == nil {
t.Fatal("should have gotten error when adding a non-existent file")
}
if err := w.Remove("no-such-file"); err == nil {
t.Fatal("should have gotten error when removing non-existent watch")
}
f, err := ioutil.TempFile("", "asdf")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(f.Name())
if err := w.Add(f.Name()); err != nil {
t.Fatal(err)
}
if err := w.Remove(f.Name()); err != nil {
t.Fatal(err)
}
}
func TestPollerEvent(t *testing.T) {
w := NewPollingWatcher()
f, err := ioutil.TempFile("", "test-poller")
if err != nil {
t.Fatal("error creating temp file")
}
defer os.RemoveAll(f.Name())
f.Close()
if err := w.Add(f.Name()); err != nil {
t.Fatal(err)
}
select {
case <-w.Events():
t.Fatal("got event before anything happened")
case <-w.Errors():
t.Fatal("got error before anything happened")
default:
}
if err := ioutil.WriteFile(f.Name(), []byte("hello"), 644); err != nil {
t.Fatal(err)
}
if err := assertEvent(w, fsnotify.Write); err != nil {
t.Fatal(err)
}
if err := os.Chmod(f.Name(), 600); err != nil {
t.Fatal(err)
}
if err := assertEvent(w, fsnotify.Chmod); err != nil {
t.Fatal(err)
}
if err := os.Remove(f.Name()); err != nil {
t.Fatal(err)
}
if err := assertEvent(w, fsnotify.Remove); err != nil {
t.Fatal(err)
}
}
func TestPollerClose(t *testing.T) {
w := NewPollingWatcher()
if err := w.Close(); err != nil {
t.Fatal(err)
}
// test double-close
if err := w.Close(); err != nil {
t.Fatal(err)
}
select {
case _, open := <-w.Events():
if open {
t.Fatal("event chan should be closed")
}
default:
t.Fatal("event chan should be closed")
}
select {
case _, open := <-w.Errors():
if open {
t.Fatal("errors chan should be closed")
}
default:
t.Fatal("errors chan should be closed")
}
f, err := ioutil.TempFile("", "asdf")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(f.Name())
if err := w.Add(f.Name()); err == nil {
t.Fatal("should have gotten error adding watch for closed watcher")
}
}
func assertEvent(w FileWatcher, eType fsnotify.Op) error {
var err error
select {
case e := <-w.Events():
if e.Op != eType {
err = fmt.Errorf("got wrong event type, expected %q: %v", eType, e)
}
case e := <-w.Errors():
err = fmt.Errorf("got unexpected error waiting for events %v: %v", eType, e)
case <-time.After(watchWaitTime * 3):
err = fmt.Errorf("timeout waiting for event %v", eType)
}
return err
}

View File

@@ -0,0 +1,279 @@
package fileutils
import (
"errors"
"fmt"
"io"
"os"
"path/filepath"
"regexp"
"strings"
"text/scanner"
"github.com/Sirupsen/logrus"
)
// exclusion return true if the specified pattern is an exclusion
func exclusion(pattern string) bool {
return pattern[0] == '!'
}
// empty return true if the specified pattern is empty
func empty(pattern string) bool {
return pattern == ""
}
// CleanPatterns takes a slice of patterns returns a new
// slice of patterns cleaned with filepath.Clean, stripped
// of any empty patterns and lets the caller know whether the
// slice contains any exception patterns (prefixed with !).
func CleanPatterns(patterns []string) ([]string, [][]string, bool, error) {
// Loop over exclusion patterns and:
// 1. Clean them up.
// 2. Indicate whether we are dealing with any exception rules.
// 3. Error if we see a single exclusion marker on it's own (!).
cleanedPatterns := []string{}
patternDirs := [][]string{}
exceptions := false
for _, pattern := range patterns {
// Eliminate leading and trailing whitespace.
pattern = strings.TrimSpace(pattern)
if empty(pattern) {
continue
}
if exclusion(pattern) {
if len(pattern) == 1 {
return nil, nil, false, errors.New("Illegal exclusion pattern: !")
}
exceptions = true
}
pattern = filepath.Clean(pattern)
cleanedPatterns = append(cleanedPatterns, pattern)
if exclusion(pattern) {
pattern = pattern[1:]
}
patternDirs = append(patternDirs, strings.Split(pattern, "/"))
}
return cleanedPatterns, patternDirs, exceptions, nil
}
// Matches returns true if file matches any of the patterns
// and isn't excluded by any of the subsequent patterns.
func Matches(file string, patterns []string) (bool, error) {
file = filepath.Clean(file)
if file == "." {
// Don't let them exclude everything, kind of silly.
return false, nil
}
patterns, patDirs, _, err := CleanPatterns(patterns)
if err != nil {
return false, err
}
return OptimizedMatches(file, patterns, patDirs)
}
// OptimizedMatches is basically the same as fileutils.Matches() but optimized for archive.go.
// It will assume that the inputs have been preprocessed and therefore the function
// doesn't need to do as much error checking and clean-up. This was done to avoid
// repeating these steps on each file being checked during the archive process.
// The more generic fileutils.Matches() can't make these assumptions.
func OptimizedMatches(file string, patterns []string, patDirs [][]string) (bool, error) {
matched := false
parentPath := filepath.Dir(file)
parentPathDirs := strings.Split(parentPath, "/")
for i, pattern := range patterns {
negative := false
if exclusion(pattern) {
negative = true
pattern = pattern[1:]
}
match, err := regexpMatch(pattern, file)
if err != nil {
return false, fmt.Errorf("Error in pattern (%s): %s", pattern, err)
}
if !match && parentPath != "." {
// Check to see if the pattern matches one of our parent dirs.
if len(patDirs[i]) <= len(parentPathDirs) {
match, _ = regexpMatch(strings.Join(patDirs[i], "/"),
strings.Join(parentPathDirs[:len(patDirs[i])], "/"))
}
}
if match {
matched = !negative
}
}
if matched {
logrus.Debugf("Skipping excluded path: %s", file)
}
return matched, nil
}
// regexpMatch tries to match the logic of filepath.Match but
// does so using regexp logic. We do this so that we can expand the
// wildcard set to include other things, like "**" to mean any number
// of directories. This means that we should be backwards compatible
// with filepath.Match(). We'll end up supporting more stuff, due to
// the fact that we're using regexp, but that's ok - it does no harm.
func regexpMatch(pattern, path string) (bool, error) {
regStr := "^"
// Do some syntax checking on the pattern.
// filepath's Match() has some really weird rules that are inconsistent
// so instead of trying to dup their logic, just call Match() for its
// error state and if there is an error in the pattern return it.
// If this becomes an issue we can remove this since its really only
// needed in the error (syntax) case - which isn't really critical.
if _, err := filepath.Match(pattern, path); err != nil {
return false, err
}
// Go through the pattern and convert it to a regexp.
// We use a scanner so we can support utf-8 chars.
var scan scanner.Scanner
scan.Init(strings.NewReader(pattern))
sl := string(os.PathSeparator)
escSL := sl
if sl == `\` {
escSL += `\`
}
for scan.Peek() != scanner.EOF {
ch := scan.Next()
if ch == '*' {
if scan.Peek() == '*' {
// is some flavor of "**"
scan.Next()
if scan.Peek() == scanner.EOF {
// is "**EOF" - to align with .gitignore just accept all
regStr += ".*"
} else {
// is "**"
regStr += "((.*" + escSL + ")|([^" + escSL + "]*))"
}
// Treat **/ as ** so eat the "/"
if string(scan.Peek()) == sl {
scan.Next()
}
} else {
// is "*" so map it to anything but "/"
regStr += "[^" + escSL + "]*"
}
} else if ch == '?' {
// "?" is any char except "/"
regStr += "[^" + escSL + "]"
} else if strings.Index(".$", string(ch)) != -1 {
// Escape some regexp special chars that have no meaning
// in golang's filepath.Match
regStr += `\` + string(ch)
} else if ch == '\\' {
// escape next char. Note that a trailing \ in the pattern
// will be left alone (but need to escape it)
if sl == `\` {
// On windows map "\" to "\\", meaning an escaped backslash,
// and then just continue because filepath.Match on
// Windows doesn't allow escaping at all
regStr += escSL
continue
}
if scan.Peek() != scanner.EOF {
regStr += `\` + string(scan.Next())
} else {
regStr += `\`
}
} else {
regStr += string(ch)
}
}
regStr += "$"
res, err := regexp.MatchString(regStr, path)
// Map regexp's error to filepath's so no one knows we're not using filepath
if err != nil {
err = filepath.ErrBadPattern
}
return res, err
}
// CopyFile copies from src to dst until either EOF is reached
// on src or an error occurs. It verifies src exists and remove
// the dst if it exists.
func CopyFile(src, dst string) (int64, error) {
cleanSrc := filepath.Clean(src)
cleanDst := filepath.Clean(dst)
if cleanSrc == cleanDst {
return 0, nil
}
sf, err := os.Open(cleanSrc)
if err != nil {
return 0, err
}
defer sf.Close()
if err := os.Remove(cleanDst); err != nil && !os.IsNotExist(err) {
return 0, err
}
df, err := os.Create(cleanDst)
if err != nil {
return 0, err
}
defer df.Close()
return io.Copy(df, sf)
}
// ReadSymlinkedDirectory returns the target directory of a symlink.
// The target of the symbolic link may not be a file.
func ReadSymlinkedDirectory(path string) (string, error) {
var realPath string
var err error
if realPath, err = filepath.Abs(path); err != nil {
return "", fmt.Errorf("unable to get absolute path for %s: %s", path, err)
}
if realPath, err = filepath.EvalSymlinks(realPath); err != nil {
return "", fmt.Errorf("failed to canonicalise path for %s: %s", path, err)
}
realPathInfo, err := os.Stat(realPath)
if err != nil {
return "", fmt.Errorf("failed to stat target '%s' of '%s': %s", realPath, path, err)
}
if !realPathInfo.Mode().IsDir() {
return "", fmt.Errorf("canonical path points to a file '%s'", realPath)
}
return realPath, nil
}
// CreateIfNotExists creates a file or a directory only if it does not already exist.
func CreateIfNotExists(path string, isDir bool) error {
if _, err := os.Stat(path); err != nil {
if os.IsNotExist(err) {
if isDir {
return os.MkdirAll(path, 0755)
}
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
return err
}
f, err := os.OpenFile(path, os.O_CREATE, 0755)
if err != nil {
return err
}
f.Close()
}
}
return nil
}

View File

@@ -0,0 +1,573 @@
package fileutils
import (
"io/ioutil"
"os"
"path"
"path/filepath"
"runtime"
"strings"
"testing"
)
// CopyFile with invalid src
func TestCopyFileWithInvalidSrc(t *testing.T) {
tempFolder, err := ioutil.TempDir("", "docker-fileutils-test")
defer os.RemoveAll(tempFolder)
if err != nil {
t.Fatal(err)
}
bytes, err := CopyFile("/invalid/file/path", path.Join(tempFolder, "dest"))
if err == nil {
t.Fatal("Should have fail to copy an invalid src file")
}
if bytes != 0 {
t.Fatal("Should have written 0 bytes")
}
}
// CopyFile with invalid dest
func TestCopyFileWithInvalidDest(t *testing.T) {
tempFolder, err := ioutil.TempDir("", "docker-fileutils-test")
defer os.RemoveAll(tempFolder)
if err != nil {
t.Fatal(err)
}
src := path.Join(tempFolder, "file")
err = ioutil.WriteFile(src, []byte("content"), 0740)
if err != nil {
t.Fatal(err)
}
bytes, err := CopyFile(src, path.Join(tempFolder, "/invalid/dest/path"))
if err == nil {
t.Fatal("Should have fail to copy an invalid src file")
}
if bytes != 0 {
t.Fatal("Should have written 0 bytes")
}
}
// CopyFile with same src and dest
func TestCopyFileWithSameSrcAndDest(t *testing.T) {
tempFolder, err := ioutil.TempDir("", "docker-fileutils-test")
defer os.RemoveAll(tempFolder)
if err != nil {
t.Fatal(err)
}
file := path.Join(tempFolder, "file")
err = ioutil.WriteFile(file, []byte("content"), 0740)
if err != nil {
t.Fatal(err)
}
bytes, err := CopyFile(file, file)
if err != nil {
t.Fatal(err)
}
if bytes != 0 {
t.Fatal("Should have written 0 bytes as it is the same file.")
}
}
// CopyFile with same src and dest but path is different and not clean
func TestCopyFileWithSameSrcAndDestWithPathNameDifferent(t *testing.T) {
tempFolder, err := ioutil.TempDir("", "docker-fileutils-test")
defer os.RemoveAll(tempFolder)
if err != nil {
t.Fatal(err)
}
testFolder := path.Join(tempFolder, "test")
err = os.MkdirAll(testFolder, 0740)
if err != nil {
t.Fatal(err)
}
file := path.Join(testFolder, "file")
sameFile := testFolder + "/../test/file"
err = ioutil.WriteFile(file, []byte("content"), 0740)
if err != nil {
t.Fatal(err)
}
bytes, err := CopyFile(file, sameFile)
if err != nil {
t.Fatal(err)
}
if bytes != 0 {
t.Fatal("Should have written 0 bytes as it is the same file.")
}
}
func TestCopyFile(t *testing.T) {
tempFolder, err := ioutil.TempDir("", "docker-fileutils-test")
defer os.RemoveAll(tempFolder)
if err != nil {
t.Fatal(err)
}
src := path.Join(tempFolder, "src")
dest := path.Join(tempFolder, "dest")
ioutil.WriteFile(src, []byte("content"), 0777)
ioutil.WriteFile(dest, []byte("destContent"), 0777)
bytes, err := CopyFile(src, dest)
if err != nil {
t.Fatal(err)
}
if bytes != 7 {
t.Fatalf("Should have written %d bytes but wrote %d", 7, bytes)
}
actual, err := ioutil.ReadFile(dest)
if err != nil {
t.Fatal(err)
}
if string(actual) != "content" {
t.Fatalf("Dest content was '%s', expected '%s'", string(actual), "content")
}
}
// Reading a symlink to a directory must return the directory
func TestReadSymlinkedDirectoryExistingDirectory(t *testing.T) {
var err error
if err = os.Mkdir("/tmp/testReadSymlinkToExistingDirectory", 0777); err != nil {
t.Errorf("failed to create directory: %s", err)
}
if err = os.Symlink("/tmp/testReadSymlinkToExistingDirectory", "/tmp/dirLinkTest"); err != nil {
t.Errorf("failed to create symlink: %s", err)
}
var path string
if path, err = ReadSymlinkedDirectory("/tmp/dirLinkTest"); err != nil {
t.Fatalf("failed to read symlink to directory: %s", err)
}
if path != "/tmp/testReadSymlinkToExistingDirectory" {
t.Fatalf("symlink returned unexpected directory: %s", path)
}
if err = os.Remove("/tmp/testReadSymlinkToExistingDirectory"); err != nil {
t.Errorf("failed to remove temporary directory: %s", err)
}
if err = os.Remove("/tmp/dirLinkTest"); err != nil {
t.Errorf("failed to remove symlink: %s", err)
}
}
// Reading a non-existing symlink must fail
func TestReadSymlinkedDirectoryNonExistingSymlink(t *testing.T) {
var path string
var err error
if path, err = ReadSymlinkedDirectory("/tmp/test/foo/Non/ExistingPath"); err == nil {
t.Fatalf("error expected for non-existing symlink")
}
if path != "" {
t.Fatalf("expected empty path, but '%s' was returned", path)
}
}
// Reading a symlink to a file must fail
func TestReadSymlinkedDirectoryToFile(t *testing.T) {
var err error
var file *os.File
if file, err = os.Create("/tmp/testReadSymlinkToFile"); err != nil {
t.Fatalf("failed to create file: %s", err)
}
file.Close()
if err = os.Symlink("/tmp/testReadSymlinkToFile", "/tmp/fileLinkTest"); err != nil {
t.Errorf("failed to create symlink: %s", err)
}
var path string
if path, err = ReadSymlinkedDirectory("/tmp/fileLinkTest"); err == nil {
t.Fatalf("ReadSymlinkedDirectory on a symlink to a file should've failed")
}
if path != "" {
t.Fatalf("path should've been empty: %s", path)
}
if err = os.Remove("/tmp/testReadSymlinkToFile"); err != nil {
t.Errorf("failed to remove file: %s", err)
}
if err = os.Remove("/tmp/fileLinkTest"); err != nil {
t.Errorf("failed to remove symlink: %s", err)
}
}
func TestWildcardMatches(t *testing.T) {
match, _ := Matches("fileutils.go", []string{"*"})
if match != true {
t.Errorf("failed to get a wildcard match, got %v", match)
}
}
// A simple pattern match should return true.
func TestPatternMatches(t *testing.T) {
match, _ := Matches("fileutils.go", []string{"*.go"})
if match != true {
t.Errorf("failed to get a match, got %v", match)
}
}
// An exclusion followed by an inclusion should return true.
func TestExclusionPatternMatchesPatternBefore(t *testing.T) {
match, _ := Matches("fileutils.go", []string{"!fileutils.go", "*.go"})
if match != true {
t.Errorf("failed to get true match on exclusion pattern, got %v", match)
}
}
// A folder pattern followed by an exception should return false.
func TestPatternMatchesFolderExclusions(t *testing.T) {
match, _ := Matches("docs/README.md", []string{"docs", "!docs/README.md"})
if match != false {
t.Errorf("failed to get a false match on exclusion pattern, got %v", match)
}
}
// A folder pattern followed by an exception should return false.
func TestPatternMatchesFolderWithSlashExclusions(t *testing.T) {
match, _ := Matches("docs/README.md", []string{"docs/", "!docs/README.md"})
if match != false {
t.Errorf("failed to get a false match on exclusion pattern, got %v", match)
}
}
// A folder pattern followed by an exception should return false.
func TestPatternMatchesFolderWildcardExclusions(t *testing.T) {
match, _ := Matches("docs/README.md", []string{"docs/*", "!docs/README.md"})
if match != false {
t.Errorf("failed to get a false match on exclusion pattern, got %v", match)
}
}
// A pattern followed by an exclusion should return false.
func TestExclusionPatternMatchesPatternAfter(t *testing.T) {
match, _ := Matches("fileutils.go", []string{"*.go", "!fileutils.go"})
if match != false {
t.Errorf("failed to get false match on exclusion pattern, got %v", match)
}
}
// A filename evaluating to . should return false.
func TestExclusionPatternMatchesWholeDirectory(t *testing.T) {
match, _ := Matches(".", []string{"*.go"})
if match != false {
t.Errorf("failed to get false match on ., got %v", match)
}
}
// A single ! pattern should return an error.
func TestSingleExclamationError(t *testing.T) {
_, err := Matches("fileutils.go", []string{"!"})
if err == nil {
t.Errorf("failed to get an error for a single exclamation point, got %v", err)
}
}
// A string preceded with a ! should return true from Exclusion.
func TestExclusion(t *testing.T) {
exclusion := exclusion("!")
if !exclusion {
t.Errorf("failed to get true for a single !, got %v", exclusion)
}
}
// Matches with no patterns
func TestMatchesWithNoPatterns(t *testing.T) {
matches, err := Matches("/any/path/there", []string{})
if err != nil {
t.Fatal(err)
}
if matches {
t.Fatalf("Should not have match anything")
}
}
// Matches with malformed patterns
func TestMatchesWithMalformedPatterns(t *testing.T) {
matches, err := Matches("/any/path/there", []string{"["})
if err == nil {
t.Fatal("Should have failed because of a malformed syntax in the pattern")
}
if matches {
t.Fatalf("Should not have match anything")
}
}
// Test lots of variants of patterns & strings
func TestMatches(t *testing.T) {
tests := []struct {
pattern string
text string
pass bool
}{
{"**", "file", true},
{"**", "file/", true},
{"**/", "file", true}, // weird one
{"**/", "file/", true},
{"**", "/", true},
{"**/", "/", true},
{"**", "dir/file", true},
{"**/", "dir/file", false},
{"**", "dir/file/", true},
{"**/", "dir/file/", true},
{"**/**", "dir/file", true},
{"**/**", "dir/file/", true},
{"dir/**", "dir/file", true},
{"dir/**", "dir/file/", true},
{"dir/**", "dir/dir2/file", true},
{"dir/**", "dir/dir2/file/", true},
{"**/dir2/*", "dir/dir2/file", true},
{"**/dir2/*", "dir/dir2/file/", false},
{"**/dir2/**", "dir/dir2/dir3/file", true},
{"**/dir2/**", "dir/dir2/dir3/file/", true},
{"**file", "file", true},
{"**file", "dir/file", true},
{"**/file", "dir/file", true},
{"**file", "dir/dir/file", true},
{"**/file", "dir/dir/file", true},
{"**/file*", "dir/dir/file", true},
{"**/file*", "dir/dir/file.txt", true},
{"**/file*txt", "dir/dir/file.txt", true},
{"**/file*.txt", "dir/dir/file.txt", true},
{"**/file*.txt*", "dir/dir/file.txt", true},
{"**/**/*.txt", "dir/dir/file.txt", true},
{"**/**/*.txt2", "dir/dir/file.txt", false},
{"**/*.txt", "file.txt", true},
{"**/**/*.txt", "file.txt", true},
{"a**/*.txt", "a/file.txt", true},
{"a**/*.txt", "a/dir/file.txt", true},
{"a**/*.txt", "a/dir/dir/file.txt", true},
{"a/*.txt", "a/dir/file.txt", false},
{"a/*.txt", "a/file.txt", true},
{"a/*.txt**", "a/file.txt", true},
{"a[b-d]e", "ae", false},
{"a[b-d]e", "ace", true},
{"a[b-d]e", "aae", false},
{"a[^b-d]e", "aze", true},
{".*", ".foo", true},
{".*", "foo", false},
{"abc.def", "abcdef", false},
{"abc.def", "abc.def", true},
{"abc.def", "abcZdef", false},
{"abc?def", "abcZdef", true},
{"abc?def", "abcdef", false},
{"a\\*b", "a*b", true},
{"a\\", "a", false},
{"a\\", "a\\", false},
{"a\\\\", "a\\", true},
{"**/foo/bar", "foo/bar", true},
{"**/foo/bar", "dir/foo/bar", true},
{"**/foo/bar", "dir/dir2/foo/bar", true},
{"abc/**", "abc", false},
{"abc/**", "abc/def", true},
{"abc/**", "abc/def/ghi", true},
}
for _, test := range tests {
res, _ := regexpMatch(test.pattern, test.text)
if res != test.pass {
t.Fatalf("Failed: %v - res:%v", test, res)
}
}
}
// An empty string should return true from Empty.
func TestEmpty(t *testing.T) {
empty := empty("")
if !empty {
t.Errorf("failed to get true for an empty string, got %v", empty)
}
}
func TestCleanPatterns(t *testing.T) {
cleaned, _, _, _ := CleanPatterns([]string{"docs", "config"})
if len(cleaned) != 2 {
t.Errorf("expected 2 element slice, got %v", len(cleaned))
}
}
func TestCleanPatternsStripEmptyPatterns(t *testing.T) {
cleaned, _, _, _ := CleanPatterns([]string{"docs", "config", ""})
if len(cleaned) != 2 {
t.Errorf("expected 2 element slice, got %v", len(cleaned))
}
}
func TestCleanPatternsExceptionFlag(t *testing.T) {
_, _, exceptions, _ := CleanPatterns([]string{"docs", "!docs/README.md"})
if !exceptions {
t.Errorf("expected exceptions to be true, got %v", exceptions)
}
}
func TestCleanPatternsLeadingSpaceTrimmed(t *testing.T) {
_, _, exceptions, _ := CleanPatterns([]string{"docs", " !docs/README.md"})
if !exceptions {
t.Errorf("expected exceptions to be true, got %v", exceptions)
}
}
func TestCleanPatternsTrailingSpaceTrimmed(t *testing.T) {
_, _, exceptions, _ := CleanPatterns([]string{"docs", "!docs/README.md "})
if !exceptions {
t.Errorf("expected exceptions to be true, got %v", exceptions)
}
}
func TestCleanPatternsErrorSingleException(t *testing.T) {
_, _, _, err := CleanPatterns([]string{"!"})
if err == nil {
t.Errorf("expected error on single exclamation point, got %v", err)
}
}
func TestCleanPatternsFolderSplit(t *testing.T) {
_, dirs, _, _ := CleanPatterns([]string{"docs/config/CONFIG.md"})
if dirs[0][0] != "docs" {
t.Errorf("expected first element in dirs slice to be docs, got %v", dirs[0][1])
}
if dirs[0][1] != "config" {
t.Errorf("expected first element in dirs slice to be config, got %v", dirs[0][1])
}
}
func TestCreateIfNotExistsDir(t *testing.T) {
tempFolder, err := ioutil.TempDir("", "docker-fileutils-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tempFolder)
folderToCreate := filepath.Join(tempFolder, "tocreate")
if err := CreateIfNotExists(folderToCreate, true); err != nil {
t.Fatal(err)
}
fileinfo, err := os.Stat(folderToCreate)
if err != nil {
t.Fatalf("Should have create a folder, got %v", err)
}
if !fileinfo.IsDir() {
t.Fatalf("Should have been a dir, seems it's not")
}
}
func TestCreateIfNotExistsFile(t *testing.T) {
tempFolder, err := ioutil.TempDir("", "docker-fileutils-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tempFolder)
fileToCreate := filepath.Join(tempFolder, "file/to/create")
if err := CreateIfNotExists(fileToCreate, false); err != nil {
t.Fatal(err)
}
fileinfo, err := os.Stat(fileToCreate)
if err != nil {
t.Fatalf("Should have create a file, got %v", err)
}
if fileinfo.IsDir() {
t.Fatalf("Should have been a file, seems it's not")
}
}
// These matchTests are stolen from go's filepath Match tests.
type matchTest struct {
pattern, s string
match bool
err error
}
var matchTests = []matchTest{
{"abc", "abc", true, nil},
{"*", "abc", true, nil},
{"*c", "abc", true, nil},
{"a*", "a", true, nil},
{"a*", "abc", true, nil},
{"a*", "ab/c", false, nil},
{"a*/b", "abc/b", true, nil},
{"a*/b", "a/c/b", false, nil},
{"a*b*c*d*e*/f", "axbxcxdxe/f", true, nil},
{"a*b*c*d*e*/f", "axbxcxdxexxx/f", true, nil},
{"a*b*c*d*e*/f", "axbxcxdxe/xxx/f", false, nil},
{"a*b*c*d*e*/f", "axbxcxdxexxx/fff", false, nil},
{"a*b?c*x", "abxbbxdbxebxczzx", true, nil},
{"a*b?c*x", "abxbbxdbxebxczzy", false, nil},
{"ab[c]", "abc", true, nil},
{"ab[b-d]", "abc", true, nil},
{"ab[e-g]", "abc", false, nil},
{"ab[^c]", "abc", false, nil},
{"ab[^b-d]", "abc", false, nil},
{"ab[^e-g]", "abc", true, nil},
{"a\\*b", "a*b", true, nil},
{"a\\*b", "ab", false, nil},
{"a?b", "a☺b", true, nil},
{"a[^a]b", "a☺b", true, nil},
{"a???b", "a☺b", false, nil},
{"a[^a][^a][^a]b", "a☺b", false, nil},
{"[a-ζ]*", "α", true, nil},
{"*[a-ζ]", "A", false, nil},
{"a?b", "a/b", false, nil},
{"a*b", "a/b", false, nil},
{"[\\]a]", "]", true, nil},
{"[\\-]", "-", true, nil},
{"[x\\-]", "x", true, nil},
{"[x\\-]", "-", true, nil},
{"[x\\-]", "z", false, nil},
{"[\\-x]", "x", true, nil},
{"[\\-x]", "-", true, nil},
{"[\\-x]", "a", false, nil},
{"[]a]", "]", false, filepath.ErrBadPattern},
{"[-]", "-", false, filepath.ErrBadPattern},
{"[x-]", "x", false, filepath.ErrBadPattern},
{"[x-]", "-", false, filepath.ErrBadPattern},
{"[x-]", "z", false, filepath.ErrBadPattern},
{"[-x]", "x", false, filepath.ErrBadPattern},
{"[-x]", "-", false, filepath.ErrBadPattern},
{"[-x]", "a", false, filepath.ErrBadPattern},
{"\\", "a", false, filepath.ErrBadPattern},
{"[a-b-c]", "a", false, filepath.ErrBadPattern},
{"[", "a", false, filepath.ErrBadPattern},
{"[^", "a", false, filepath.ErrBadPattern},
{"[^bc", "a", false, filepath.ErrBadPattern},
{"a[", "a", false, filepath.ErrBadPattern}, // was nil but IMO its wrong
{"a[", "ab", false, filepath.ErrBadPattern},
{"*x", "xxx", true, nil},
}
func errp(e error) string {
if e == nil {
return "<nil>"
}
return e.Error()
}
// TestMatch test's our version of filepath.Match, called regexpMatch.
func TestMatch(t *testing.T) {
for _, tt := range matchTests {
pattern := tt.pattern
s := tt.s
if runtime.GOOS == "windows" {
if strings.Index(pattern, "\\") >= 0 {
// no escape allowed on windows.
continue
}
pattern = filepath.Clean(pattern)
s = filepath.Clean(s)
}
ok, err := regexpMatch(pattern, s)
if ok != tt.match || err != tt.err {
t.Fatalf("Match(%#q, %#q) = %v, %q want %v, %q", pattern, s, ok, errp(err), tt.match, errp(tt.err))
}
}
}

View File

@@ -0,0 +1,22 @@
// +build linux freebsd
package fileutils
import (
"fmt"
"io/ioutil"
"os"
"github.com/Sirupsen/logrus"
)
// GetTotalUsedFds Returns the number of used File Descriptors by
// reading it via /proc filesystem.
func GetTotalUsedFds() int {
if fds, err := ioutil.ReadDir(fmt.Sprintf("/proc/%d/fd", os.Getpid())); err != nil {
logrus.Errorf("Error opening /proc/%d/fd: %s", os.Getpid(), err)
} else {
return len(fds)
}
return -1
}

View File

@@ -0,0 +1,7 @@
package fileutils
// GetTotalUsedFds Returns the number of used File Descriptors. Not supported
// on Windows.
func GetTotalUsedFds() int {
return -1
}

View File

@@ -0,0 +1,100 @@
package gitutils
import (
"fmt"
"io/ioutil"
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
"strings"
"github.com/hyperhq/hypercli/pkg/symlink"
"github.com/hyperhq/hypercli/pkg/urlutil"
)
// Clone clones a repository into a newly created directory which
// will be under "docker-build-git"
func Clone(remoteURL string) (string, error) {
if !urlutil.IsGitTransport(remoteURL) {
remoteURL = "https://" + remoteURL
}
root, err := ioutil.TempDir("", "docker-build-git")
if err != nil {
return "", err
}
u, err := url.Parse(remoteURL)
if err != nil {
return "", err
}
fragment := u.Fragment
clone := cloneArgs(u, root)
if output, err := git(clone...); err != nil {
return "", fmt.Errorf("Error trying to use git: %s (%s)", err, output)
}
return checkoutGit(fragment, root)
}
func cloneArgs(remoteURL *url.URL, root string) []string {
args := []string{"clone", "--recursive"}
shallow := len(remoteURL.Fragment) == 0
if shallow && strings.HasPrefix(remoteURL.Scheme, "http") {
res, err := http.Head(fmt.Sprintf("%s/info/refs?service=git-upload-pack", remoteURL))
if err != nil || res.Header.Get("Content-Type") != "application/x-git-upload-pack-advertisement" {
shallow = false
}
}
if shallow {
args = append(args, "--depth", "1")
}
if remoteURL.Fragment != "" {
remoteURL.Fragment = ""
}
return append(args, remoteURL.String(), root)
}
func checkoutGit(fragment, root string) (string, error) {
refAndDir := strings.SplitN(fragment, ":", 2)
if len(refAndDir[0]) != 0 {
if output, err := gitWithinDir(root, "checkout", refAndDir[0]); err != nil {
return "", fmt.Errorf("Error trying to use git: %s (%s)", err, output)
}
}
if len(refAndDir) > 1 && len(refAndDir[1]) != 0 {
newCtx, err := symlink.FollowSymlinkInScope(filepath.Join(root, refAndDir[1]), root)
if err != nil {
return "", fmt.Errorf("Error setting git context, %q not within git root: %s", refAndDir[1], err)
}
fi, err := os.Stat(newCtx)
if err != nil {
return "", err
}
if !fi.IsDir() {
return "", fmt.Errorf("Error setting git context, not a directory: %s", newCtx)
}
root = newCtx
}
return root, nil
}
func gitWithinDir(dir string, args ...string) ([]byte, error) {
a := []string{"--work-tree", dir, "--git-dir", filepath.Join(dir, ".git")}
return git(append(a, args...)...)
}
func git(args ...string) ([]byte, error) {
return exec.Command("git", args...).CombinedOutput()
}

View File

@@ -0,0 +1,186 @@
package gitutils
import (
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"reflect"
"testing"
)
func TestCloneArgsSmartHttp(t *testing.T) {
mux := http.NewServeMux()
server := httptest.NewServer(mux)
serverURL, _ := url.Parse(server.URL)
serverURL.Path = "/repo.git"
gitURL := serverURL.String()
mux.HandleFunc("/repo.git/info/refs", func(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query().Get("service")
w.Header().Set("Content-Type", fmt.Sprintf("application/x-%s-advertisement", q))
})
args := cloneArgs(serverURL, "/tmp")
exp := []string{"clone", "--recursive", "--depth", "1", gitURL, "/tmp"}
if !reflect.DeepEqual(args, exp) {
t.Fatalf("Expected %v, got %v", exp, args)
}
}
func TestCloneArgsDumbHttp(t *testing.T) {
mux := http.NewServeMux()
server := httptest.NewServer(mux)
serverURL, _ := url.Parse(server.URL)
serverURL.Path = "/repo.git"
gitURL := serverURL.String()
mux.HandleFunc("/repo.git/info/refs", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
})
args := cloneArgs(serverURL, "/tmp")
exp := []string{"clone", "--recursive", gitURL, "/tmp"}
if !reflect.DeepEqual(args, exp) {
t.Fatalf("Expected %v, got %v", exp, args)
}
}
func TestCloneArgsGit(t *testing.T) {
u, _ := url.Parse("git://github.com/hyperhq/hypercli")
args := cloneArgs(u, "/tmp")
exp := []string{"clone", "--recursive", "--depth", "1", "git://github.com/hyperhq/hypercli", "/tmp"}
if !reflect.DeepEqual(args, exp) {
t.Fatalf("Expected %v, got %v", exp, args)
}
}
func TestCloneArgsStripFragment(t *testing.T) {
u, _ := url.Parse("git://github.com/hyperhq/hypercli#test")
args := cloneArgs(u, "/tmp")
exp := []string{"clone", "--recursive", "git://github.com/hyperhq/hypercli", "/tmp"}
if !reflect.DeepEqual(args, exp) {
t.Fatalf("Expected %v, got %v", exp, args)
}
}
func TestCheckoutGit(t *testing.T) {
root, err := ioutil.TempDir("", "docker-build-git-checkout")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(root)
gitDir := filepath.Join(root, "repo")
_, err = git("init", gitDir)
if err != nil {
t.Fatal(err)
}
if _, err = gitWithinDir(gitDir, "config", "user.email", "test@docker.com"); err != nil {
t.Fatal(err)
}
if _, err = gitWithinDir(gitDir, "config", "user.name", "Docker test"); err != nil {
t.Fatal(err)
}
if err = ioutil.WriteFile(filepath.Join(gitDir, "Dockerfile"), []byte("FROM scratch"), 0644); err != nil {
t.Fatal(err)
}
subDir := filepath.Join(gitDir, "subdir")
if err = os.Mkdir(subDir, 0755); err != nil {
t.Fatal(err)
}
if err = ioutil.WriteFile(filepath.Join(subDir, "Dockerfile"), []byte("FROM scratch\nEXPOSE 5000"), 0644); err != nil {
t.Fatal(err)
}
if err = os.Symlink("../subdir", filepath.Join(gitDir, "parentlink")); err != nil {
t.Fatal(err)
}
if err = os.Symlink("/subdir", filepath.Join(gitDir, "absolutelink")); err != nil {
t.Fatal(err)
}
if _, err = gitWithinDir(gitDir, "add", "-A"); err != nil {
t.Fatal(err)
}
if _, err = gitWithinDir(gitDir, "commit", "-am", "First commit"); err != nil {
t.Fatal(err)
}
if _, err = gitWithinDir(gitDir, "checkout", "-b", "test"); err != nil {
t.Fatal(err)
}
if err = ioutil.WriteFile(filepath.Join(gitDir, "Dockerfile"), []byte("FROM scratch\nEXPOSE 3000"), 0644); err != nil {
t.Fatal(err)
}
if err = ioutil.WriteFile(filepath.Join(subDir, "Dockerfile"), []byte("FROM busybox\nEXPOSE 5000"), 0644); err != nil {
t.Fatal(err)
}
if _, err = gitWithinDir(gitDir, "add", "-A"); err != nil {
t.Fatal(err)
}
if _, err = gitWithinDir(gitDir, "commit", "-am", "Branch commit"); err != nil {
t.Fatal(err)
}
if _, err = gitWithinDir(gitDir, "checkout", "master"); err != nil {
t.Fatal(err)
}
cases := []struct {
frag string
exp string
fail bool
}{
{"", "FROM scratch", false},
{"master", "FROM scratch", false},
{":subdir", "FROM scratch\nEXPOSE 5000", false},
{":nosubdir", "", true}, // missing directory error
{":Dockerfile", "", true}, // not a directory error
{"master:nosubdir", "", true},
{"master:subdir", "FROM scratch\nEXPOSE 5000", false},
{"master:parentlink", "FROM scratch\nEXPOSE 5000", false},
{"master:absolutelink", "FROM scratch\nEXPOSE 5000", false},
{"master:../subdir", "", true},
{"test", "FROM scratch\nEXPOSE 3000", false},
{"test:", "FROM scratch\nEXPOSE 3000", false},
{"test:subdir", "FROM busybox\nEXPOSE 5000", false},
}
for _, c := range cases {
r, err := checkoutGit(c.frag, gitDir)
fail := err != nil
if fail != c.fail {
t.Fatalf("Expected %v failure, error was %v\n", c.fail, err)
}
if c.fail {
continue
}
b, err := ioutil.ReadFile(filepath.Join(r, "Dockerfile"))
if err != nil {
t.Fatal(err)
}
if string(b) != c.exp {
t.Fatalf("Expected %v, was %v\n", c.exp, string(b))
}
}
}

View File

@@ -0,0 +1,15 @@
// +build cgo
package graphdb
import "database/sql"
// NewSqliteConn opens a connection to a sqlite
// database.
func NewSqliteConn(root string) (*Database, error) {
conn, err := sql.Open("sqlite3", root)
if err != nil {
return nil, err
}
return NewDatabase(conn)
}

View File

@@ -0,0 +1,7 @@
// +build cgo,!windows
package graphdb
import (
_ "github.com/mattn/go-sqlite3" // registers sqlite
)

View File

@@ -0,0 +1,7 @@
// +build cgo,windows
package graphdb
import (
_ "github.com/mattn/go-sqlite3" // registers sqlite
)

View File

@@ -0,0 +1,8 @@
// +build !cgo
package graphdb
// NewSqliteConn return a new sqlite connection.
func NewSqliteConn(root string) (*Database, error) {
panic("Not implemented")
}

View File

@@ -0,0 +1,551 @@
package graphdb
import (
"database/sql"
"fmt"
"path"
"strings"
"sync"
)
const (
createEntityTable = `
CREATE TABLE IF NOT EXISTS entity (
id text NOT NULL PRIMARY KEY
);`
createEdgeTable = `
CREATE TABLE IF NOT EXISTS edge (
"entity_id" text NOT NULL,
"parent_id" text NULL,
"name" text NOT NULL,
CONSTRAINT "parent_fk" FOREIGN KEY ("parent_id") REFERENCES "entity" ("id"),
CONSTRAINT "entity_fk" FOREIGN KEY ("entity_id") REFERENCES "entity" ("id")
);
`
createEdgeIndices = `
CREATE UNIQUE INDEX IF NOT EXISTS "name_parent_ix" ON "edge" (parent_id, name);
`
)
// Entity with a unique id.
type Entity struct {
id string
}
// An Edge connects two entities together.
type Edge struct {
EntityID string
Name string
ParentID string
}
// Entities stores the list of entities.
type Entities map[string]*Entity
// Edges stores the relationships between entities.
type Edges []*Edge
// WalkFunc is a function invoked to process an individual entity.
type WalkFunc func(fullPath string, entity *Entity) error
// Database is a graph database for storing entities and their relationships.
type Database struct {
conn *sql.DB
mux sync.RWMutex
}
// IsNonUniqueNameError processes the error to check if it's caused by
// a constraint violation.
// This is necessary because the error isn't the same across various
// sqlite versions.
func IsNonUniqueNameError(err error) bool {
str := err.Error()
// sqlite 3.7.17-1ubuntu1 returns:
// Set failure: Abort due to constraint violation: columns parent_id, name are not unique
if strings.HasSuffix(str, "name are not unique") {
return true
}
// sqlite-3.8.3-1.fc20 returns:
// Set failure: Abort due to constraint violation: UNIQUE constraint failed: edge.parent_id, edge.name
if strings.Contains(str, "UNIQUE constraint failed") && strings.Contains(str, "edge.name") {
return true
}
// sqlite-3.6.20-1.el6 returns:
// Set failure: Abort due to constraint violation: constraint failed
if strings.HasSuffix(str, "constraint failed") {
return true
}
return false
}
// NewDatabase creates a new graph database initialized with a root entity.
func NewDatabase(conn *sql.DB) (*Database, error) {
if conn == nil {
return nil, fmt.Errorf("Database connection cannot be nil")
}
db := &Database{conn: conn}
// Create root entities
tx, err := conn.Begin()
if err != nil {
return nil, err
}
if _, err := tx.Exec(createEntityTable); err != nil {
return nil, err
}
if _, err := tx.Exec(createEdgeTable); err != nil {
return nil, err
}
if _, err := tx.Exec(createEdgeIndices); err != nil {
return nil, err
}
if _, err := tx.Exec("DELETE FROM entity where id = ?", "0"); err != nil {
tx.Rollback()
return nil, err
}
if _, err := tx.Exec("INSERT INTO entity (id) VALUES (?);", "0"); err != nil {
tx.Rollback()
return nil, err
}
if _, err := tx.Exec("DELETE FROM edge where entity_id=? and name=?", "0", "/"); err != nil {
tx.Rollback()
return nil, err
}
if _, err := tx.Exec("INSERT INTO edge (entity_id, name) VALUES(?,?);", "0", "/"); err != nil {
tx.Rollback()
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return db, nil
}
// Close the underlying connection to the database.
func (db *Database) Close() error {
return db.conn.Close()
}
// Set the entity id for a given path.
func (db *Database) Set(fullPath, id string) (*Entity, error) {
db.mux.Lock()
defer db.mux.Unlock()
tx, err := db.conn.Begin()
if err != nil {
return nil, err
}
var entityID string
if err := tx.QueryRow("SELECT id FROM entity WHERE id = ?;", id).Scan(&entityID); err != nil {
if err == sql.ErrNoRows {
if _, err := tx.Exec("INSERT INTO entity (id) VALUES(?);", id); err != nil {
tx.Rollback()
return nil, err
}
} else {
tx.Rollback()
return nil, err
}
}
e := &Entity{id}
parentPath, name := splitPath(fullPath)
if err := db.setEdge(parentPath, name, e, tx); err != nil {
tx.Rollback()
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return e, nil
}
// Exists returns true if a name already exists in the database.
func (db *Database) Exists(name string) bool {
db.mux.RLock()
defer db.mux.RUnlock()
e, err := db.get(name)
if err != nil {
return false
}
return e != nil
}
func (db *Database) setEdge(parentPath, name string, e *Entity, tx *sql.Tx) error {
parent, err := db.get(parentPath)
if err != nil {
return err
}
if parent.id == e.id {
return fmt.Errorf("Cannot set self as child")
}
if _, err := tx.Exec("INSERT INTO edge (parent_id, name, entity_id) VALUES (?,?,?);", parent.id, name, e.id); err != nil {
return err
}
return nil
}
// RootEntity returns the root "/" entity for the database.
func (db *Database) RootEntity() *Entity {
return &Entity{
id: "0",
}
}
// Get returns the entity for a given path.
func (db *Database) Get(name string) *Entity {
db.mux.RLock()
defer db.mux.RUnlock()
e, err := db.get(name)
if err != nil {
return nil
}
return e
}
func (db *Database) get(name string) (*Entity, error) {
e := db.RootEntity()
// We always know the root name so return it if
// it is requested
if name == "/" {
return e, nil
}
parts := split(name)
for i := 1; i < len(parts); i++ {
p := parts[i]
if p == "" {
continue
}
next := db.child(e, p)
if next == nil {
return nil, fmt.Errorf("Cannot find child for %s", name)
}
e = next
}
return e, nil
}
// List all entities by from the name.
// The key will be the full path of the entity.
func (db *Database) List(name string, depth int) Entities {
db.mux.RLock()
defer db.mux.RUnlock()
out := Entities{}
e, err := db.get(name)
if err != nil {
return out
}
children, err := db.children(e, name, depth, nil)
if err != nil {
return out
}
for _, c := range children {
out[c.FullPath] = c.Entity
}
return out
}
// Walk through the child graph of an entity, calling walkFunc for each child entity.
// It is safe for walkFunc to call graph functions.
func (db *Database) Walk(name string, walkFunc WalkFunc, depth int) error {
children, err := db.Children(name, depth)
if err != nil {
return err
}
// Note: the database lock must not be held while calling walkFunc
for _, c := range children {
if err := walkFunc(c.FullPath, c.Entity); err != nil {
return err
}
}
return nil
}
// Children returns the children of the specified entity.
func (db *Database) Children(name string, depth int) ([]WalkMeta, error) {
db.mux.RLock()
defer db.mux.RUnlock()
e, err := db.get(name)
if err != nil {
return nil, err
}
return db.children(e, name, depth, nil)
}
// Parents returns the parents of a specified entity.
func (db *Database) Parents(name string) ([]string, error) {
db.mux.RLock()
defer db.mux.RUnlock()
e, err := db.get(name)
if err != nil {
return nil, err
}
return db.parents(e)
}
// Refs returns the reference count for a specified id.
func (db *Database) Refs(id string) int {
db.mux.RLock()
defer db.mux.RUnlock()
var count int
if err := db.conn.QueryRow("SELECT COUNT(*) FROM edge WHERE entity_id = ?;", id).Scan(&count); err != nil {
return 0
}
return count
}
// RefPaths returns all the id's path references.
func (db *Database) RefPaths(id string) Edges {
db.mux.RLock()
defer db.mux.RUnlock()
refs := Edges{}
rows, err := db.conn.Query("SELECT name, parent_id FROM edge WHERE entity_id = ?;", id)
if err != nil {
return refs
}
defer rows.Close()
for rows.Next() {
var name string
var parentID string
if err := rows.Scan(&name, &parentID); err != nil {
return refs
}
refs = append(refs, &Edge{
EntityID: id,
Name: name,
ParentID: parentID,
})
}
return refs
}
// Delete the reference to an entity at a given path.
func (db *Database) Delete(name string) error {
db.mux.Lock()
defer db.mux.Unlock()
if name == "/" {
return fmt.Errorf("Cannot delete root entity")
}
parentPath, n := splitPath(name)
parent, err := db.get(parentPath)
if err != nil {
return err
}
if _, err := db.conn.Exec("DELETE FROM edge WHERE parent_id = ? AND name = ?;", parent.id, n); err != nil {
return err
}
return nil
}
// Purge removes the entity with the specified id
// Walk the graph to make sure all references to the entity
// are removed and return the number of references removed
func (db *Database) Purge(id string) (int, error) {
db.mux.Lock()
defer db.mux.Unlock()
tx, err := db.conn.Begin()
if err != nil {
return -1, err
}
// Delete all edges
rows, err := tx.Exec("DELETE FROM edge WHERE entity_id = ?;", id)
if err != nil {
tx.Rollback()
return -1, err
}
changes, err := rows.RowsAffected()
if err != nil {
return -1, err
}
// Clear who's using this id as parent
refs, err := tx.Exec("DELETE FROM edge WHERE parent_id = ?;", id)
if err != nil {
tx.Rollback()
return -1, err
}
refsCount, err := refs.RowsAffected()
if err != nil {
return -1, err
}
// Delete entity
if _, err := tx.Exec("DELETE FROM entity where id = ?;", id); err != nil {
tx.Rollback()
return -1, err
}
if err := tx.Commit(); err != nil {
return -1, err
}
return int(changes + refsCount), nil
}
// Rename an edge for a given path
func (db *Database) Rename(currentName, newName string) error {
db.mux.Lock()
defer db.mux.Unlock()
parentPath, name := splitPath(currentName)
newParentPath, newEdgeName := splitPath(newName)
if parentPath != newParentPath {
return fmt.Errorf("Cannot rename when root paths do not match %s != %s", parentPath, newParentPath)
}
parent, err := db.get(parentPath)
if err != nil {
return err
}
rows, err := db.conn.Exec("UPDATE edge SET name = ? WHERE parent_id = ? AND name = ?;", newEdgeName, parent.id, name)
if err != nil {
return err
}
i, err := rows.RowsAffected()
if err != nil {
return err
}
if i == 0 {
return fmt.Errorf("Cannot locate edge for %s %s", parent.id, name)
}
return nil
}
// WalkMeta stores the walk metadata.
type WalkMeta struct {
Parent *Entity
Entity *Entity
FullPath string
Edge *Edge
}
func (db *Database) children(e *Entity, name string, depth int, entities []WalkMeta) ([]WalkMeta, error) {
if e == nil {
return entities, nil
}
rows, err := db.conn.Query("SELECT entity_id, name FROM edge where parent_id = ?;", e.id)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var entityID, entityName string
if err := rows.Scan(&entityID, &entityName); err != nil {
return nil, err
}
child := &Entity{entityID}
edge := &Edge{
ParentID: e.id,
Name: entityName,
EntityID: child.id,
}
meta := WalkMeta{
Parent: e,
Entity: child,
FullPath: path.Join(name, edge.Name),
Edge: edge,
}
entities = append(entities, meta)
if depth != 0 {
nDepth := depth
if depth != -1 {
nDepth--
}
entities, err = db.children(child, meta.FullPath, nDepth, entities)
if err != nil {
return nil, err
}
}
}
return entities, nil
}
func (db *Database) parents(e *Entity) (parents []string, err error) {
if e == nil {
return parents, nil
}
rows, err := db.conn.Query("SELECT parent_id FROM edge where entity_id = ?;", e.id)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var parentID string
if err := rows.Scan(&parentID); err != nil {
return nil, err
}
parents = append(parents, parentID)
}
return parents, nil
}
// Return the entity based on the parent path and name.
func (db *Database) child(parent *Entity, name string) *Entity {
var id string
if err := db.conn.QueryRow("SELECT entity_id FROM edge WHERE parent_id = ? AND name = ?;", parent.id, name).Scan(&id); err != nil {
return nil
}
return &Entity{id}
}
// ID returns the id used to reference this entity.
func (e *Entity) ID() string {
return e.id
}
// Paths returns the paths sorted by depth.
func (e Entities) Paths() []string {
out := make([]string, len(e))
var i int
for k := range e {
out[i] = k
i++
}
sortByDepth(out)
return out
}

View File

@@ -0,0 +1,657 @@
package graphdb
import (
"database/sql"
"fmt"
"os"
"path"
"strconv"
"testing"
_ "github.com/mattn/go-sqlite3"
)
func newTestDb(t *testing.T) (*Database, string) {
p := path.Join(os.TempDir(), "sqlite.db")
conn, err := sql.Open("sqlite3", p)
db, err := NewDatabase(conn)
if err != nil {
t.Fatal(err)
}
return db, p
}
func destroyTestDb(dbPath string) {
os.Remove(dbPath)
}
func TestNewDatabase(t *testing.T) {
db, dbpath := newTestDb(t)
if db == nil {
t.Fatal("Database should not be nil")
}
db.Close()
defer destroyTestDb(dbpath)
}
func TestCreateRootEntity(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
root := db.RootEntity()
if root == nil {
t.Fatal("Root entity should not be nil")
}
}
func TestGetRootEntity(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
e := db.Get("/")
if e == nil {
t.Fatal("Entity should not be nil")
}
if e.ID() != "0" {
t.Fatalf("Entity id should be 0, got %s", e.ID())
}
}
func TestSetEntityWithDifferentName(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
db.Set("/test", "1")
if _, err := db.Set("/other", "1"); err != nil {
t.Fatal(err)
}
}
func TestSetDuplicateEntity(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
if _, err := db.Set("/foo", "42"); err != nil {
t.Fatal(err)
}
if _, err := db.Set("/foo", "43"); err == nil {
t.Fatalf("Creating an entry with a duplicate path did not cause an error")
}
}
func TestCreateChild(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
child, err := db.Set("/db", "1")
if err != nil {
t.Fatal(err)
}
if child == nil {
t.Fatal("Child should not be nil")
}
if child.ID() != "1" {
t.Fail()
}
}
func TestParents(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
for i := 1; i < 6; i++ {
a := strconv.Itoa(i)
if _, err := db.Set("/"+a, a); err != nil {
t.Fatal(err)
}
}
for i := 6; i < 11; i++ {
a := strconv.Itoa(i)
p := strconv.Itoa(i - 5)
key := fmt.Sprintf("/%s/%s", p, a)
if _, err := db.Set(key, a); err != nil {
t.Fatal(err)
}
parents, err := db.Parents(key)
if err != nil {
t.Fatal(err)
}
if len(parents) != 1 {
t.Fatalf("Expected 1 entry for %s got %d", key, len(parents))
}
if parents[0] != p {
t.Fatalf("ID %s received, %s expected", parents[0], p)
}
}
}
func TestChildren(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
str := "/"
for i := 1; i < 6; i++ {
a := strconv.Itoa(i)
if _, err := db.Set(str+a, a); err != nil {
t.Fatal(err)
}
str = str + a + "/"
}
str = "/"
for i := 10; i < 30; i++ { // 20 entities
a := strconv.Itoa(i)
if _, err := db.Set(str+a, a); err != nil {
t.Fatal(err)
}
str = str + a + "/"
}
entries, err := db.Children("/", 5)
if err != nil {
t.Fatal(err)
}
if len(entries) != 11 {
t.Fatalf("Expect 11 entries for / got %d", len(entries))
}
entries, err = db.Children("/", 20)
if err != nil {
t.Fatal(err)
}
if len(entries) != 25 {
t.Fatalf("Expect 25 entries for / got %d", len(entries))
}
}
func TestListAllRootChildren(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
for i := 1; i < 6; i++ {
a := strconv.Itoa(i)
if _, err := db.Set("/"+a, a); err != nil {
t.Fatal(err)
}
}
entries := db.List("/", -1)
if len(entries) != 5 {
t.Fatalf("Expect 5 entries for / got %d", len(entries))
}
}
func TestListAllSubChildren(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
_, err := db.Set("/webapp", "1")
if err != nil {
t.Fatal(err)
}
child2, err := db.Set("/db", "2")
if err != nil {
t.Fatal(err)
}
child4, err := db.Set("/logs", "4")
if err != nil {
t.Fatal(err)
}
if _, err := db.Set("/db/logs", child4.ID()); err != nil {
t.Fatal(err)
}
child3, err := db.Set("/sentry", "3")
if err != nil {
t.Fatal(err)
}
if _, err := db.Set("/webapp/sentry", child3.ID()); err != nil {
t.Fatal(err)
}
if _, err := db.Set("/webapp/db", child2.ID()); err != nil {
t.Fatal(err)
}
entries := db.List("/webapp", 1)
if len(entries) != 3 {
t.Fatalf("Expect 3 entries for / got %d", len(entries))
}
entries = db.List("/webapp", 0)
if len(entries) != 2 {
t.Fatalf("Expect 2 entries for / got %d", len(entries))
}
}
func TestAddSelfAsChild(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
child, err := db.Set("/test", "1")
if err != nil {
t.Fatal(err)
}
if _, err := db.Set("/test/other", child.ID()); err == nil {
t.Fatal("Error should not be nil")
}
}
func TestAddChildToNonExistentRoot(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
if _, err := db.Set("/myapp", "1"); err != nil {
t.Fatal(err)
}
if _, err := db.Set("/myapp/proxy/db", "2"); err == nil {
t.Fatal("Error should not be nil")
}
}
func TestWalkAll(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
_, err := db.Set("/webapp", "1")
if err != nil {
t.Fatal(err)
}
child2, err := db.Set("/db", "2")
if err != nil {
t.Fatal(err)
}
child4, err := db.Set("/db/logs", "4")
if err != nil {
t.Fatal(err)
}
if _, err := db.Set("/webapp/logs", child4.ID()); err != nil {
t.Fatal(err)
}
child3, err := db.Set("/sentry", "3")
if err != nil {
t.Fatal(err)
}
if _, err := db.Set("/webapp/sentry", child3.ID()); err != nil {
t.Fatal(err)
}
if _, err := db.Set("/webapp/db", child2.ID()); err != nil {
t.Fatal(err)
}
child5, err := db.Set("/gograph", "5")
if err != nil {
t.Fatal(err)
}
if _, err := db.Set("/webapp/same-ref-diff-name", child5.ID()); err != nil {
t.Fatal(err)
}
if err := db.Walk("/", func(p string, e *Entity) error {
t.Logf("Path: %s Entity: %s", p, e.ID())
return nil
}, -1); err != nil {
t.Fatal(err)
}
}
func TestGetEntityByPath(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
_, err := db.Set("/webapp", "1")
if err != nil {
t.Fatal(err)
}
child2, err := db.Set("/db", "2")
if err != nil {
t.Fatal(err)
}
child4, err := db.Set("/logs", "4")
if err != nil {
t.Fatal(err)
}
if _, err := db.Set("/db/logs", child4.ID()); err != nil {
t.Fatal(err)
}
child3, err := db.Set("/sentry", "3")
if err != nil {
t.Fatal(err)
}
if _, err := db.Set("/webapp/sentry", child3.ID()); err != nil {
t.Fatal(err)
}
if _, err := db.Set("/webapp/db", child2.ID()); err != nil {
t.Fatal(err)
}
child5, err := db.Set("/gograph", "5")
if err != nil {
t.Fatal(err)
}
if _, err := db.Set("/webapp/same-ref-diff-name", child5.ID()); err != nil {
t.Fatal(err)
}
entity := db.Get("/webapp/db/logs")
if entity == nil {
t.Fatal("Entity should not be nil")
}
if entity.ID() != "4" {
t.Fatalf("Expected to get entity with id 4, got %s", entity.ID())
}
}
func TestEnitiesPaths(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
_, err := db.Set("/webapp", "1")
if err != nil {
t.Fatal(err)
}
child2, err := db.Set("/db", "2")
if err != nil {
t.Fatal(err)
}
child4, err := db.Set("/logs", "4")
if err != nil {
t.Fatal(err)
}
if _, err := db.Set("/db/logs", child4.ID()); err != nil {
t.Fatal(err)
}
child3, err := db.Set("/sentry", "3")
if err != nil {
t.Fatal(err)
}
if _, err := db.Set("/webapp/sentry", child3.ID()); err != nil {
t.Fatal(err)
}
if _, err := db.Set("/webapp/db", child2.ID()); err != nil {
t.Fatal(err)
}
child5, err := db.Set("/gograph", "5")
if err != nil {
t.Fatal(err)
}
if _, err := db.Set("/webapp/same-ref-diff-name", child5.ID()); err != nil {
t.Fatal(err)
}
out := db.List("/", -1)
for _, p := range out.Paths() {
t.Log(p)
}
}
func TestDeleteRootEntity(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
if err := db.Delete("/"); err == nil {
t.Fatal("Error should not be nil")
}
}
func TestDeleteEntity(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
_, err := db.Set("/webapp", "1")
if err != nil {
t.Fatal(err)
}
child2, err := db.Set("/db", "2")
if err != nil {
t.Fatal(err)
}
child4, err := db.Set("/logs", "4")
if err != nil {
t.Fatal(err)
}
if _, err := db.Set("/db/logs", child4.ID()); err != nil {
t.Fatal(err)
}
child3, err := db.Set("/sentry", "3")
if err != nil {
t.Fatal(err)
}
if _, err := db.Set("/webapp/sentry", child3.ID()); err != nil {
t.Fatal(err)
}
if _, err := db.Set("/webapp/db", child2.ID()); err != nil {
t.Fatal(err)
}
child5, err := db.Set("/gograph", "5")
if err != nil {
t.Fatal(err)
}
if _, err := db.Set("/webapp/same-ref-diff-name", child5.ID()); err != nil {
t.Fatal(err)
}
if err := db.Delete("/webapp/sentry"); err != nil {
t.Fatal(err)
}
entity := db.Get("/webapp/sentry")
if entity != nil {
t.Fatal("Entity /webapp/sentry should be nil")
}
}
func TestCountRefs(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
db.Set("/webapp", "1")
if db.Refs("1") != 1 {
t.Fatal("Expect reference count to be 1")
}
db.Set("/db", "2")
db.Set("/webapp/db", "2")
if db.Refs("2") != 2 {
t.Fatal("Expect reference count to be 2")
}
}
func TestPurgeId(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
db.Set("/webapp", "1")
if c := db.Refs("1"); c != 1 {
t.Fatalf("Expect reference count to be 1, got %d", c)
}
db.Set("/db", "2")
db.Set("/webapp/db", "2")
count, err := db.Purge("2")
if err != nil {
t.Fatal(err)
}
if count != 2 {
t.Fatalf("Expected 2 references to be removed, got %d", count)
}
}
// Regression test https://github.com/hyperhq/hypercli/issues/12334
func TestPurgeIdRefPaths(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
db.Set("/webapp", "1")
db.Set("/db", "2")
db.Set("/db/webapp", "1")
if c := db.Refs("1"); c != 2 {
t.Fatalf("Expected 2 reference for webapp, got %d", c)
}
if c := db.Refs("2"); c != 1 {
t.Fatalf("Expected 1 reference for db, got %d", c)
}
if rp := db.RefPaths("2"); len(rp) != 1 {
t.Fatalf("Expected 1 reference path for db, got %d", len(rp))
}
count, err := db.Purge("2")
if err != nil {
t.Fatal(err)
}
if count != 2 {
t.Fatalf("Expected 2 rows to be removed, got %d", count)
}
if c := db.Refs("2"); c != 0 {
t.Fatalf("Expected 0 reference for db, got %d", c)
}
if c := db.Refs("1"); c != 1 {
t.Fatalf("Expected 1 reference for webapp, got %d", c)
}
}
func TestRename(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
db.Set("/webapp", "1")
if db.Refs("1") != 1 {
t.Fatal("Expect reference count to be 1")
}
db.Set("/db", "2")
db.Set("/webapp/db", "2")
if db.Get("/webapp/db") == nil {
t.Fatal("Cannot find entity at path /webapp/db")
}
if err := db.Rename("/webapp/db", "/webapp/newdb"); err != nil {
t.Fatal(err)
}
if db.Get("/webapp/db") != nil {
t.Fatal("Entity should not exist at /webapp/db")
}
if db.Get("/webapp/newdb") == nil {
t.Fatal("Cannot find entity at path /webapp/newdb")
}
}
func TestCreateMultipleNames(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
db.Set("/db", "1")
if _, err := db.Set("/myapp", "1"); err != nil {
t.Fatal(err)
}
db.Walk("/", func(p string, e *Entity) error {
t.Logf("%s\n", p)
return nil
}, -1)
}
func TestRefPaths(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
db.Set("/webapp", "1")
db.Set("/db", "2")
db.Set("/webapp/db", "2")
refs := db.RefPaths("2")
if len(refs) != 2 {
t.Fatalf("Expected reference count to be 2, got %d", len(refs))
}
}
func TestExistsTrue(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
db.Set("/testing", "1")
if !db.Exists("/testing") {
t.Fatalf("/tesing should exist")
}
}
func TestExistsFalse(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
db.Set("/toerhe", "1")
if db.Exists("/testing") {
t.Fatalf("/tesing should not exist")
}
}
func TestGetNameWithTrailingSlash(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
db.Set("/todo", "1")
e := db.Get("/todo/")
if e == nil {
t.Fatalf("Entity should not be nil")
}
}
func TestConcurrentWrites(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
errs := make(chan error, 2)
save := func(name string, id string) {
if _, err := db.Set(fmt.Sprintf("/%s", name), id); err != nil {
errs <- err
}
errs <- nil
}
purge := func(id string) {
if _, err := db.Purge(id); err != nil {
errs <- err
}
errs <- nil
}
save("/1", "1")
go purge("1")
go save("/2", "2")
any := false
for i := 0; i < 2; i++ {
if err := <-errs; err != nil {
any = true
t.Log(err)
}
}
if any {
t.Fail()
}
}

27
vendor/github.com/hyperhq/hypercli/pkg/graphdb/sort.go generated vendored Normal file
View File

@@ -0,0 +1,27 @@
package graphdb
import "sort"
type pathSorter struct {
paths []string
by func(i, j string) bool
}
func sortByDepth(paths []string) {
s := &pathSorter{paths, func(i, j string) bool {
return PathDepth(i) > PathDepth(j)
}}
sort.Sort(s)
}
func (s *pathSorter) Len() int {
return len(s.paths)
}
func (s *pathSorter) Swap(i, j int) {
s.paths[i], s.paths[j] = s.paths[j], s.paths[i]
}
func (s *pathSorter) Less(i, j int) bool {
return s.by(s.paths[i], s.paths[j])
}

View File

@@ -0,0 +1,29 @@
package graphdb
import (
"testing"
)
func TestSort(t *testing.T) {
paths := []string{
"/",
"/myreallylongname",
"/app/db",
}
sortByDepth(paths)
if len(paths) != 3 {
t.Fatalf("Expected 3 parts got %d", len(paths))
}
if paths[0] != "/app/db" {
t.Fatalf("Expected /app/db got %s", paths[0])
}
if paths[1] != "/myreallylongname" {
t.Fatalf("Expected /myreallylongname got %s", paths[1])
}
if paths[2] != "/" {
t.Fatalf("Expected / got %s", paths[2])
}
}

View File

@@ -0,0 +1,32 @@
package graphdb
import (
"path"
"strings"
)
// Split p on /
func split(p string) []string {
return strings.Split(p, "/")
}
// PathDepth returns the depth or number of / in a given path
func PathDepth(p string) int {
parts := split(p)
if len(parts) == 2 && parts[1] == "" {
return 1
}
return len(parts)
}
func splitPath(p string) (parent, name string) {
if p[0] != '/' {
p = "/" + p
}
parent, name = path.Split(p)
l := len(parent)
if parent[l-1] == '/' {
parent = parent[:l-1]
}
return
}

View File

@@ -0,0 +1,39 @@
package homedir
import (
"os"
"runtime"
"github.com/opencontainers/runc/libcontainer/user"
)
// Key returns the env var name for the user's home dir based on
// the platform being run on
func Key() string {
if runtime.GOOS == "windows" {
return "USERPROFILE"
}
return "HOME"
}
// Get returns the home directory of the current user with the help of
// environment variables depending on the target operating system.
// Returned path should be used with "path/filepath" to form new paths.
func Get() string {
home := os.Getenv(Key())
if home == "" && runtime.GOOS != "windows" {
if u, err := user.CurrentUser(); err == nil {
return u.Home
}
}
return home
}
// GetShortcutString returns the string that is shortcut to user's home directory
// in the native shell of the platform running on.
func GetShortcutString() string {
if runtime.GOOS == "windows" {
return "%USERPROFILE%" // be careful while using in format functions
}
return "~"
}

View File

@@ -0,0 +1,24 @@
package homedir
import (
"path/filepath"
"testing"
)
func TestGet(t *testing.T) {
home := Get()
if home == "" {
t.Fatal("returned home directory is empty")
}
if !filepath.IsAbs(home) {
t.Fatalf("returned path is not absolute: %s", home)
}
}
func TestGetShortcutString(t *testing.T) {
shortcut := GetShortcutString()
if shortcut == "" {
t.Fatal("returned shortcut string is empty")
}
}

View File

@@ -0,0 +1,56 @@
package httputils
import (
"errors"
"fmt"
"net/http"
"regexp"
"strings"
"github.com/hyperhq/hypercli/pkg/jsonmessage"
)
var (
headerRegexp = regexp.MustCompile(`^(?:(.+)/(.+?))\((.+)\).*$`)
errInvalidHeader = errors.New("Bad header, should be in format `docker/version (platform)`")
)
// Download requests a given URL and returns an io.Reader.
func Download(url string) (resp *http.Response, err error) {
if resp, err = http.Get(url); err != nil {
return nil, err
}
if resp.StatusCode >= 400 {
return nil, fmt.Errorf("Got HTTP status code >= 400: %s", resp.Status)
}
return resp, nil
}
// NewHTTPRequestError returns a JSON response error.
func NewHTTPRequestError(msg string, res *http.Response) error {
return &jsonmessage.JSONError{
Message: msg,
Code: res.StatusCode,
}
}
// ServerHeader contains the server information.
type ServerHeader struct {
App string // docker
Ver string // 1.8.0-dev
OS string // windows or linux
}
// ParseServerHeader extracts pieces from an HTTP server header
// which is in the format "docker/version (os)" eg docker/1.8.0-dev (windows).
func ParseServerHeader(hdr string) (*ServerHeader, error) {
matches := headerRegexp.FindStringSubmatch(hdr)
if len(matches) != 4 {
return nil, errInvalidHeader
}
return &ServerHeader{
App: strings.TrimSpace(matches[1]),
Ver: strings.TrimSpace(matches[2]),
OS: strings.TrimSpace(matches[3]),
}, nil
}

View File

@@ -0,0 +1,115 @@
package httputils
import (
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestDownload(t *testing.T) {
expected := "Hello, docker !"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, expected)
}))
defer ts.Close()
response, err := Download(ts.URL)
if err != nil {
t.Fatal(err)
}
actual, err := ioutil.ReadAll(response.Body)
response.Body.Close()
if err != nil || string(actual) != expected {
t.Fatalf("Expected the response %q, got err:%v, response:%v, actual:%s", expected, err, response, string(actual))
}
}
func TestDownload400Errors(t *testing.T) {
expectedError := "Got HTTP status code >= 400: 403 Forbidden"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 403
http.Error(w, "something failed (forbidden)", http.StatusForbidden)
}))
defer ts.Close()
// Expected status code = 403
if _, err := Download(ts.URL); err == nil || err.Error() != expectedError {
t.Fatalf("Expected the the error %q, got %v", expectedError, err)
}
}
func TestDownloadOtherErrors(t *testing.T) {
if _, err := Download("I'm not an url.."); err == nil || !strings.Contains(err.Error(), "unsupported protocol scheme") {
t.Fatalf("Expected an error with 'unsupported protocol scheme', got %v", err)
}
}
func TestNewHTTPRequestError(t *testing.T) {
errorMessage := "Some error message"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 403
http.Error(w, errorMessage, http.StatusForbidden)
}))
defer ts.Close()
httpResponse, err := http.Get(ts.URL)
if err != nil {
t.Fatal(err)
}
if err := NewHTTPRequestError(errorMessage, httpResponse); err.Error() != errorMessage {
t.Fatalf("Expected err to be %q, got %v", errorMessage, err)
}
}
func TestParseServerHeader(t *testing.T) {
inputs := map[string][]string{
"bad header": {"error"},
"(bad header)": {"error"},
"(without/spaces)": {"error"},
"(header/with spaces)": {"error"},
"foo/bar (baz)": {"foo", "bar", "baz"},
"foo/bar": {"error"},
"foo": {"error"},
"foo/bar (baz space)": {"foo", "bar", "baz space"},
" f f / b b ( b s ) ": {"f f", "b b", "b s"},
"foo/bar (baz) ignore": {"foo", "bar", "baz"},
"foo/bar ()": {"error"},
"foo/bar()": {"error"},
"foo/bar(baz)": {"foo", "bar", "baz"},
"foo/bar/zzz(baz)": {"foo/bar", "zzz", "baz"},
"foo/bar(baz/abc)": {"foo", "bar", "baz/abc"},
"foo/bar(baz (abc))": {"foo", "bar", "baz (abc)"},
}
for header, values := range inputs {
serverHeader, err := ParseServerHeader(header)
if err != nil {
if err != errInvalidHeader {
t.Fatalf("Failed to parse %q, and got some unexpected error: %q", header, err)
}
if values[0] == "error" {
continue
}
t.Fatalf("Header %q failed to parse when it shouldn't have", header)
}
if values[0] == "error" {
t.Fatalf("Header %q parsed ok when it should have failed(%q).", header, serverHeader)
}
if serverHeader.App != values[0] {
t.Fatalf("Expected serverHeader.App for %q to equal %q, got %q", header, values[0], serverHeader.App)
}
if serverHeader.Ver != values[1] {
t.Fatalf("Expected serverHeader.Ver for %q to equal %q, got %q", header, values[1], serverHeader.Ver)
}
if serverHeader.OS != values[2] {
t.Fatalf("Expected serverHeader.OS for %q to equal %q, got %q", header, values[2], serverHeader.OS)
}
}
}

View File

@@ -0,0 +1,30 @@
package httputils
import (
"mime"
"net/http"
)
// MimeTypes stores the MIME content type.
var MimeTypes = struct {
TextPlain string
Tar string
OctetStream string
}{"text/plain", "application/tar", "application/octet-stream"}
// DetectContentType returns a best guess representation of the MIME
// content type for the bytes at c. The value detected by
// http.DetectContentType is guaranteed not be nil, defaulting to
// application/octet-stream when a better guess cannot be made. The
// result of this detection is then run through mime.ParseMediaType()
// which separates the actual MIME string from any parameters.
func DetectContentType(c []byte) (string, map[string]string, error) {
ct := http.DetectContentType(c)
contentType, args, err := mime.ParseMediaType(ct)
if err != nil {
return "", nil, err
}
return contentType, args, nil
}

View File

@@ -0,0 +1,13 @@
package httputils
import (
"testing"
)
func TestDetectContentType(t *testing.T) {
input := []byte("That is just a plain text")
if contentType, _, err := DetectContentType(input); err != nil || contentType != "text/plain" {
t.Errorf("TestDetectContentType failed")
}
}

View File

@@ -0,0 +1,95 @@
package httputils
import (
"fmt"
"io"
"net/http"
"time"
"github.com/Sirupsen/logrus"
)
type resumableRequestReader struct {
client *http.Client
request *http.Request
lastRange int64
totalSize int64
currentResponse *http.Response
failures uint32
maxFailures uint32
}
// ResumableRequestReader makes it possible to resume reading a request's body transparently
// maxfail is the number of times we retry to make requests again (not resumes)
// totalsize is the total length of the body; auto detect if not provided
func ResumableRequestReader(c *http.Client, r *http.Request, maxfail uint32, totalsize int64) io.ReadCloser {
return &resumableRequestReader{client: c, request: r, maxFailures: maxfail, totalSize: totalsize}
}
// ResumableRequestReaderWithInitialResponse makes it possible to resume
// reading the body of an already initiated request.
func ResumableRequestReaderWithInitialResponse(c *http.Client, r *http.Request, maxfail uint32, totalsize int64, initialResponse *http.Response) io.ReadCloser {
return &resumableRequestReader{client: c, request: r, maxFailures: maxfail, totalSize: totalsize, currentResponse: initialResponse}
}
func (r *resumableRequestReader) Read(p []byte) (n int, err error) {
if r.client == nil || r.request == nil {
return 0, fmt.Errorf("client and request can't be nil\n")
}
isFreshRequest := false
if r.lastRange != 0 && r.currentResponse == nil {
readRange := fmt.Sprintf("bytes=%d-%d", r.lastRange, r.totalSize)
r.request.Header.Set("Range", readRange)
time.Sleep(5 * time.Second)
}
if r.currentResponse == nil {
r.currentResponse, err = r.client.Do(r.request)
isFreshRequest = true
}
if err != nil && r.failures+1 != r.maxFailures {
r.cleanUpResponse()
r.failures++
time.Sleep(5 * time.Duration(r.failures) * time.Second)
return 0, nil
} else if err != nil {
r.cleanUpResponse()
return 0, err
}
if r.currentResponse.StatusCode == 416 && r.lastRange == r.totalSize && r.currentResponse.ContentLength == 0 {
r.cleanUpResponse()
return 0, io.EOF
} else if r.currentResponse.StatusCode != 206 && r.lastRange != 0 && isFreshRequest {
r.cleanUpResponse()
return 0, fmt.Errorf("the server doesn't support byte ranges")
}
if r.totalSize == 0 {
r.totalSize = r.currentResponse.ContentLength
} else if r.totalSize <= 0 {
r.cleanUpResponse()
return 0, fmt.Errorf("failed to auto detect content length")
}
n, err = r.currentResponse.Body.Read(p)
r.lastRange += int64(n)
if err != nil {
r.cleanUpResponse()
}
if err != nil && err != io.EOF {
logrus.Infof("encountered error during pull and clearing it before resume: %s", err)
err = nil
}
return n, err
}
func (r *resumableRequestReader) Close() error {
r.cleanUpResponse()
r.client = nil
r.request = nil
return nil
}
func (r *resumableRequestReader) cleanUpResponse() {
if r.currentResponse != nil {
r.currentResponse.Body.Close()
r.currentResponse = nil
}
}

View File

@@ -0,0 +1,307 @@
package httputils
import (
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestResumableRequestHeaderSimpleErrors(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, world !")
}))
defer ts.Close()
client := &http.Client{}
var req *http.Request
req, err := http.NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatal(err)
}
expectedError := "client and request can't be nil\n"
resreq := &resumableRequestReader{}
_, err = resreq.Read([]byte{})
if err == nil || err.Error() != expectedError {
t.Fatalf("Expected an error with '%s', got %v.", expectedError, err)
}
resreq = &resumableRequestReader{
client: client,
request: req,
totalSize: -1,
}
expectedError = "failed to auto detect content length"
_, err = resreq.Read([]byte{})
if err == nil || err.Error() != expectedError {
t.Fatalf("Expected an error with '%s', got %v.", expectedError, err)
}
}
// Not too much failures, bails out after some wait
func TestResumableRequestHeaderNotTooMuchFailures(t *testing.T) {
client := &http.Client{}
var badReq *http.Request
badReq, err := http.NewRequest("GET", "I'm not an url", nil)
if err != nil {
t.Fatal(err)
}
resreq := &resumableRequestReader{
client: client,
request: badReq,
failures: 0,
maxFailures: 2,
}
read, err := resreq.Read([]byte{})
if err != nil || read != 0 {
t.Fatalf("Expected no error and no byte read, got err:%v, read:%v.", err, read)
}
}
// Too much failures, returns the error
func TestResumableRequestHeaderTooMuchFailures(t *testing.T) {
client := &http.Client{}
var badReq *http.Request
badReq, err := http.NewRequest("GET", "I'm not an url", nil)
if err != nil {
t.Fatal(err)
}
resreq := &resumableRequestReader{
client: client,
request: badReq,
failures: 0,
maxFailures: 1,
}
defer resreq.Close()
expectedError := `Get I%27m%20not%20an%20url: unsupported protocol scheme ""`
read, err := resreq.Read([]byte{})
if err == nil || err.Error() != expectedError || read != 0 {
t.Fatalf("Expected the error '%s', got err:%v, read:%v.", expectedError, err, read)
}
}
type errorReaderCloser struct{}
func (errorReaderCloser) Close() error { return nil }
func (errorReaderCloser) Read(p []byte) (n int, err error) {
return 0, fmt.Errorf("A error occured")
}
// If a an unknown error is encountered, return 0, nil and log it
func TestResumableRequestReaderWithReadError(t *testing.T) {
var req *http.Request
req, err := http.NewRequest("GET", "", nil)
if err != nil {
t.Fatal(err)
}
client := &http.Client{}
response := &http.Response{
Status: "500 Internal Server",
StatusCode: 500,
ContentLength: 0,
Close: true,
Body: errorReaderCloser{},
}
resreq := &resumableRequestReader{
client: client,
request: req,
currentResponse: response,
lastRange: 1,
totalSize: 1,
}
defer resreq.Close()
buf := make([]byte, 1)
read, err := resreq.Read(buf)
if err != nil {
t.Fatal(err)
}
if read != 0 {
t.Fatalf("Expected to have read nothing, but read %v", read)
}
}
func TestResumableRequestReaderWithEOFWith416Response(t *testing.T) {
var req *http.Request
req, err := http.NewRequest("GET", "", nil)
if err != nil {
t.Fatal(err)
}
client := &http.Client{}
response := &http.Response{
Status: "416 Requested Range Not Satisfiable",
StatusCode: 416,
ContentLength: 0,
Close: true,
Body: ioutil.NopCloser(strings.NewReader("")),
}
resreq := &resumableRequestReader{
client: client,
request: req,
currentResponse: response,
lastRange: 1,
totalSize: 1,
}
defer resreq.Close()
buf := make([]byte, 1)
_, err = resreq.Read(buf)
if err == nil || err != io.EOF {
t.Fatalf("Expected an io.EOF error, got %v", err)
}
}
func TestResumableRequestReaderWithServerDoesntSupportByteRanges(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Range") == "" {
t.Fatalf("Expected a Range HTTP header, got nothing")
}
}))
defer ts.Close()
var req *http.Request
req, err := http.NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatal(err)
}
client := &http.Client{}
resreq := &resumableRequestReader{
client: client,
request: req,
lastRange: 1,
}
defer resreq.Close()
buf := make([]byte, 2)
_, err = resreq.Read(buf)
if err == nil || err.Error() != "the server doesn't support byte ranges" {
t.Fatalf("Expected an error 'the server doesn't support byte ranges', got %v", err)
}
}
func TestResumableRequestReaderWithZeroTotalSize(t *testing.T) {
srvtxt := "some response text data"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, srvtxt)
}))
defer ts.Close()
var req *http.Request
req, err := http.NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatal(err)
}
client := &http.Client{}
retries := uint32(5)
resreq := ResumableRequestReader(client, req, retries, 0)
defer resreq.Close()
data, err := ioutil.ReadAll(resreq)
if err != nil {
t.Fatal(err)
}
resstr := strings.TrimSuffix(string(data), "\n")
if resstr != srvtxt {
t.Errorf("resstr != srvtxt")
}
}
func TestResumableRequestReader(t *testing.T) {
srvtxt := "some response text data"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, srvtxt)
}))
defer ts.Close()
var req *http.Request
req, err := http.NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatal(err)
}
client := &http.Client{}
retries := uint32(5)
imgSize := int64(len(srvtxt))
resreq := ResumableRequestReader(client, req, retries, imgSize)
defer resreq.Close()
data, err := ioutil.ReadAll(resreq)
if err != nil {
t.Fatal(err)
}
resstr := strings.TrimSuffix(string(data), "\n")
if resstr != srvtxt {
t.Errorf("resstr != srvtxt")
}
}
func TestResumableRequestReaderWithInitialResponse(t *testing.T) {
srvtxt := "some response text data"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, srvtxt)
}))
defer ts.Close()
var req *http.Request
req, err := http.NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatal(err)
}
client := &http.Client{}
retries := uint32(5)
imgSize := int64(len(srvtxt))
res, err := client.Do(req)
if err != nil {
t.Fatal(err)
}
resreq := ResumableRequestReaderWithInitialResponse(client, req, retries, imgSize, res)
defer resreq.Close()
data, err := ioutil.ReadAll(resreq)
if err != nil {
t.Fatal(err)
}
resstr := strings.TrimSuffix(string(data), "\n")
if resstr != srvtxt {
t.Errorf("resstr != srvtxt")
}
}

Some files were not shown because too many files have changed in this diff Show More