Fix the dependency issue (#231)

This commit is contained in:
Robbie Zhang
2018-06-21 12:09:42 -07:00
committed by GitHub
parent 027b76651d
commit 6ec1098bb8
16629 changed files with 74837 additions and 4975021 deletions

View File

@@ -1,12 +0,0 @@
# go-ansiterm
This is a cross platform Ansi Terminal Emulation library. It reads a stream of Ansi characters and produces the appropriate function calls. The results of the function calls are platform dependent.
For example the parser might receive "ESC, [, A" as a stream of three characters. This is the code for Cursor Up (http://www.vt100.net/docs/vt510-rm/CUU). The parser then calls the cursor up function (CUU()) on an event handler. The event handler determines what platform specific work must be done to cause the cursor to move up one position.
The parser (parser.go) is a partial implementation of this state machine (http://vt100.net/emu/vt500_parser.png). There are also two event handler implementations, one for tests (test_event_handler.go) to validate that the expected events are being produced and called, the other is a Windows implementation (winterm/win_event_handler.go).
See parser_test.go for examples exercising the state machine and generating appropriate function calls.
-----
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.

View File

@@ -1,141 +0,0 @@
package ansiterm
import (
"fmt"
"testing"
)
func TestStateTransitions(t *testing.T) {
stateTransitionHelper(t, "CsiEntry", "Ground", alphabetics)
stateTransitionHelper(t, "CsiEntry", "CsiParam", csiCollectables)
stateTransitionHelper(t, "Escape", "CsiEntry", []byte{ANSI_ESCAPE_SECONDARY})
stateTransitionHelper(t, "Escape", "OscString", []byte{0x5D})
stateTransitionHelper(t, "Escape", "Ground", escapeToGroundBytes)
stateTransitionHelper(t, "Escape", "EscapeIntermediate", intermeds)
stateTransitionHelper(t, "EscapeIntermediate", "EscapeIntermediate", intermeds)
stateTransitionHelper(t, "EscapeIntermediate", "EscapeIntermediate", executors)
stateTransitionHelper(t, "EscapeIntermediate", "Ground", escapeIntermediateToGroundBytes)
stateTransitionHelper(t, "OscString", "Ground", []byte{ANSI_BEL})
stateTransitionHelper(t, "OscString", "Ground", []byte{0x5C})
stateTransitionHelper(t, "Ground", "Ground", executors)
}
func TestAnyToX(t *testing.T) {
anyToXHelper(t, []byte{ANSI_ESCAPE_PRIMARY}, "Escape")
anyToXHelper(t, []byte{DCS_ENTRY}, "DcsEntry")
anyToXHelper(t, []byte{OSC_STRING}, "OscString")
anyToXHelper(t, []byte{CSI_ENTRY}, "CsiEntry")
anyToXHelper(t, toGroundBytes, "Ground")
}
func TestCollectCsiParams(t *testing.T) {
parser, _ := createTestParser("CsiEntry")
parser.Parse(csiCollectables)
buffer := parser.context.paramBuffer
bufferCount := len(buffer)
if bufferCount != len(csiCollectables) {
t.Errorf("Buffer: %v", buffer)
t.Errorf("CsiParams: %v", csiCollectables)
t.Errorf("Buffer count failure: %d != %d", bufferCount, len(csiParams))
return
}
for i, v := range csiCollectables {
if v != buffer[i] {
t.Errorf("Buffer: %v", buffer)
t.Errorf("CsiParams: %v", csiParams)
t.Errorf("Mismatch at buffer[%d] = %d", i, buffer[i])
}
}
}
func TestParseParams(t *testing.T) {
parseParamsHelper(t, []byte{}, []string{})
parseParamsHelper(t, []byte{';'}, []string{})
parseParamsHelper(t, []byte{';', ';'}, []string{})
parseParamsHelper(t, []byte{'7'}, []string{"7"})
parseParamsHelper(t, []byte{'7', ';'}, []string{"7"})
parseParamsHelper(t, []byte{'7', ';', ';'}, []string{"7"})
parseParamsHelper(t, []byte{'7', ';', ';', '8'}, []string{"7", "8"})
parseParamsHelper(t, []byte{'7', ';', '8', ';'}, []string{"7", "8"})
parseParamsHelper(t, []byte{'7', ';', ';', '8', ';', ';'}, []string{"7", "8"})
parseParamsHelper(t, []byte{'7', '8'}, []string{"78"})
parseParamsHelper(t, []byte{'7', '8', ';'}, []string{"78"})
parseParamsHelper(t, []byte{'7', '8', ';', '9', '0'}, []string{"78", "90"})
parseParamsHelper(t, []byte{'7', '8', ';', ';', '9', '0'}, []string{"78", "90"})
parseParamsHelper(t, []byte{'7', '8', ';', '9', '0', ';'}, []string{"78", "90"})
parseParamsHelper(t, []byte{'7', '8', ';', '9', '0', ';', ';'}, []string{"78", "90"})
}
func TestCursor(t *testing.T) {
cursorSingleParamHelper(t, 'A', "CUU")
cursorSingleParamHelper(t, 'B', "CUD")
cursorSingleParamHelper(t, 'C', "CUF")
cursorSingleParamHelper(t, 'D', "CUB")
cursorSingleParamHelper(t, 'E', "CNL")
cursorSingleParamHelper(t, 'F', "CPL")
cursorSingleParamHelper(t, 'G', "CHA")
cursorTwoParamHelper(t, 'H', "CUP")
cursorTwoParamHelper(t, 'f', "HVP")
funcCallParamHelper(t, []byte{'?', '2', '5', 'h'}, "CsiEntry", "Ground", []string{"DECTCEM([true])"})
funcCallParamHelper(t, []byte{'?', '2', '5', 'l'}, "CsiEntry", "Ground", []string{"DECTCEM([false])"})
}
func TestErase(t *testing.T) {
// Erase in Display
eraseHelper(t, 'J', "ED")
// Erase in Line
eraseHelper(t, 'K', "EL")
}
func TestSelectGraphicRendition(t *testing.T) {
funcCallParamHelper(t, []byte{'m'}, "CsiEntry", "Ground", []string{"SGR([0])"})
funcCallParamHelper(t, []byte{'0', 'm'}, "CsiEntry", "Ground", []string{"SGR([0])"})
funcCallParamHelper(t, []byte{'0', ';', '1', 'm'}, "CsiEntry", "Ground", []string{"SGR([0 1])"})
funcCallParamHelper(t, []byte{'0', ';', '1', ';', '2', 'm'}, "CsiEntry", "Ground", []string{"SGR([0 1 2])"})
}
func TestScroll(t *testing.T) {
scrollHelper(t, 'S', "SU")
scrollHelper(t, 'T', "SD")
}
func TestPrint(t *testing.T) {
parser, evtHandler := createTestParser("Ground")
parser.Parse(printables)
validateState(t, parser.currState, "Ground")
for i, v := range printables {
expectedCall := fmt.Sprintf("Print([%s])", string(v))
actualCall := evtHandler.FunctionCalls[i]
if actualCall != expectedCall {
t.Errorf("Actual != Expected: %v != %v at %d", actualCall, expectedCall, i)
}
}
}
func TestClear(t *testing.T) {
p, _ := createTestParser("Ground")
fillContext(p.context)
p.clear()
validateEmptyContext(t, p.context)
}
func TestClearOnStateChange(t *testing.T) {
clearOnStateChangeHelper(t, "Ground", "Escape", []byte{ANSI_ESCAPE_PRIMARY})
clearOnStateChangeHelper(t, "Ground", "CsiEntry", []byte{CSI_ENTRY})
}
func TestC0(t *testing.T) {
expectedCall := "Execute([" + string(ANSI_LINE_FEED) + "])"
c0Helper(t, []byte{ANSI_LINE_FEED}, "Ground", []string{expectedCall})
expectedCall = "Execute([" + string(ANSI_CARRIAGE_RETURN) + "])"
c0Helper(t, []byte{ANSI_CARRIAGE_RETURN}, "Ground", []string{expectedCall})
}
func TestEscDispatch(t *testing.T) {
funcCallParamHelper(t, []byte{'M'}, "Escape", "Ground", []string{"RI([])"})
}

View File

@@ -1,114 +0,0 @@
package ansiterm
import (
"fmt"
"testing"
)
func getStateNames() []string {
parser, _ := createTestParser("Ground")
stateNames := []string{}
for _, state := range parser.stateMap {
stateNames = append(stateNames, state.Name())
}
return stateNames
}
func stateTransitionHelper(t *testing.T, start string, end string, bytes []byte) {
for _, b := range bytes {
bytes := []byte{byte(b)}
parser, _ := createTestParser(start)
parser.Parse(bytes)
validateState(t, parser.currState, end)
}
}
func anyToXHelper(t *testing.T, bytes []byte, expectedState string) {
for _, s := range getStateNames() {
stateTransitionHelper(t, s, expectedState, bytes)
}
}
func funcCallParamHelper(t *testing.T, bytes []byte, start string, expected string, expectedCalls []string) {
parser, evtHandler := createTestParser(start)
parser.Parse(bytes)
validateState(t, parser.currState, expected)
validateFuncCalls(t, evtHandler.FunctionCalls, expectedCalls)
}
func parseParamsHelper(t *testing.T, bytes []byte, expectedParams []string) {
params, err := parseParams(bytes)
if err != nil {
t.Errorf("Parameter parse error: %v", err)
return
}
if len(params) != len(expectedParams) {
t.Errorf("Parsed parameters: %v", params)
t.Errorf("Expected parameters: %v", expectedParams)
t.Errorf("Parameter length failure: %d != %d", len(params), len(expectedParams))
return
}
for i, v := range expectedParams {
if v != params[i] {
t.Errorf("Parsed parameters: %v", params)
t.Errorf("Expected parameters: %v", expectedParams)
t.Errorf("Parameter parse failure: %s != %s at position %d", v, params[i], i)
}
}
}
func cursorSingleParamHelper(t *testing.T, command byte, funcName string) {
funcCallParamHelper(t, []byte{command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([1])", funcName)})
funcCallParamHelper(t, []byte{'0', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([1])", funcName)})
funcCallParamHelper(t, []byte{'2', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([2])", funcName)})
funcCallParamHelper(t, []byte{'2', '3', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([23])", funcName)})
funcCallParamHelper(t, []byte{'2', ';', '3', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([2])", funcName)})
funcCallParamHelper(t, []byte{'2', ';', '3', ';', '4', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([2])", funcName)})
}
func cursorTwoParamHelper(t *testing.T, command byte, funcName string) {
funcCallParamHelper(t, []byte{command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([1 1])", funcName)})
funcCallParamHelper(t, []byte{'0', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([1 1])", funcName)})
funcCallParamHelper(t, []byte{'2', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([2 1])", funcName)})
funcCallParamHelper(t, []byte{'2', '3', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([23 1])", funcName)})
funcCallParamHelper(t, []byte{'2', ';', '3', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([2 3])", funcName)})
funcCallParamHelper(t, []byte{'2', ';', '3', ';', '4', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([2 3])", funcName)})
}
func eraseHelper(t *testing.T, command byte, funcName string) {
funcCallParamHelper(t, []byte{command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([0])", funcName)})
funcCallParamHelper(t, []byte{'0', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([0])", funcName)})
funcCallParamHelper(t, []byte{'1', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([1])", funcName)})
funcCallParamHelper(t, []byte{'2', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([2])", funcName)})
funcCallParamHelper(t, []byte{'3', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([3])", funcName)})
funcCallParamHelper(t, []byte{'4', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([0])", funcName)})
funcCallParamHelper(t, []byte{'1', ';', '2', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([1])", funcName)})
}
func scrollHelper(t *testing.T, command byte, funcName string) {
funcCallParamHelper(t, []byte{command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([1])", funcName)})
funcCallParamHelper(t, []byte{'0', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([1])", funcName)})
funcCallParamHelper(t, []byte{'1', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([1])", funcName)})
funcCallParamHelper(t, []byte{'5', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([5])", funcName)})
funcCallParamHelper(t, []byte{'4', ';', '6', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([4])", funcName)})
}
func clearOnStateChangeHelper(t *testing.T, start string, end string, bytes []byte) {
p, _ := createTestParser(start)
fillContext(p.context)
p.Parse(bytes)
validateState(t, p.currState, end)
validateEmptyContext(t, p.context)
}
func c0Helper(t *testing.T, bytes []byte, expectedState string, expectedCalls []string) {
parser, evtHandler := createTestParser("Ground")
parser.Parse(bytes)
validateState(t, parser.currState, expectedState)
validateFuncCalls(t, evtHandler.FunctionCalls, expectedCalls)
}

View File

@@ -1,66 +0,0 @@
package ansiterm
import (
"testing"
)
func createTestParser(s string) (*AnsiParser, *TestAnsiEventHandler) {
evtHandler := CreateTestAnsiEventHandler()
parser := CreateParser(s, evtHandler)
return parser, evtHandler
}
func validateState(t *testing.T, actualState state, expectedStateName string) {
actualName := "Nil"
if actualState != nil {
actualName = actualState.Name()
}
if actualName != expectedStateName {
t.Errorf("Invalid state: '%s' != '%s'", actualName, expectedStateName)
}
}
func validateFuncCalls(t *testing.T, actualCalls []string, expectedCalls []string) {
actualCount := len(actualCalls)
expectedCount := len(expectedCalls)
if actualCount != expectedCount {
t.Errorf("Actual calls: %v", actualCalls)
t.Errorf("Expected calls: %v", expectedCalls)
t.Errorf("Call count error: %d != %d", actualCount, expectedCount)
return
}
for i, v := range actualCalls {
if v != expectedCalls[i] {
t.Errorf("Actual calls: %v", actualCalls)
t.Errorf("Expected calls: %v", expectedCalls)
t.Errorf("Mismatched calls: %s != %s with lengths %d and %d", v, expectedCalls[i], len(v), len(expectedCalls[i]))
}
}
}
func fillContext(context *ansiContext) {
context.currentChar = 'A'
context.paramBuffer = []byte{'C', 'D', 'E'}
context.interBuffer = []byte{'F', 'G', 'H'}
}
func validateEmptyContext(t *testing.T, context *ansiContext) {
var expectedCurrChar byte = 0x0
if context.currentChar != expectedCurrChar {
t.Errorf("Currentchar mismatch '%#x' != '%#x'", context.currentChar, expectedCurrChar)
}
if len(context.paramBuffer) != 0 {
t.Errorf("Non-empty parameter buffer: %v", context.paramBuffer)
}
if len(context.paramBuffer) != 0 {
t.Errorf("Non-empty intermediate buffer: %v", context.interBuffer)
}
}

View File

@@ -1,173 +0,0 @@
package ansiterm
import (
"fmt"
"strconv"
)
type TestAnsiEventHandler struct {
FunctionCalls []string
}
func CreateTestAnsiEventHandler() *TestAnsiEventHandler {
evtHandler := TestAnsiEventHandler{}
evtHandler.FunctionCalls = make([]string, 0)
return &evtHandler
}
func (h *TestAnsiEventHandler) recordCall(call string, params []string) {
s := fmt.Sprintf("%s(%v)", call, params)
h.FunctionCalls = append(h.FunctionCalls, s)
}
func (h *TestAnsiEventHandler) Print(b byte) error {
h.recordCall("Print", []string{string(b)})
return nil
}
func (h *TestAnsiEventHandler) Execute(b byte) error {
h.recordCall("Execute", []string{string(b)})
return nil
}
func (h *TestAnsiEventHandler) CUU(param int) error {
h.recordCall("CUU", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) CUD(param int) error {
h.recordCall("CUD", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) CUF(param int) error {
h.recordCall("CUF", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) CUB(param int) error {
h.recordCall("CUB", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) CNL(param int) error {
h.recordCall("CNL", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) CPL(param int) error {
h.recordCall("CPL", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) CHA(param int) error {
h.recordCall("CHA", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) VPA(param int) error {
h.recordCall("VPA", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) CUP(x int, y int) error {
xS, yS := strconv.Itoa(x), strconv.Itoa(y)
h.recordCall("CUP", []string{xS, yS})
return nil
}
func (h *TestAnsiEventHandler) HVP(x int, y int) error {
xS, yS := strconv.Itoa(x), strconv.Itoa(y)
h.recordCall("HVP", []string{xS, yS})
return nil
}
func (h *TestAnsiEventHandler) DECTCEM(visible bool) error {
h.recordCall("DECTCEM", []string{strconv.FormatBool(visible)})
return nil
}
func (h *TestAnsiEventHandler) DECOM(visible bool) error {
h.recordCall("DECOM", []string{strconv.FormatBool(visible)})
return nil
}
func (h *TestAnsiEventHandler) DECCOLM(use132 bool) error {
h.recordCall("DECOLM", []string{strconv.FormatBool(use132)})
return nil
}
func (h *TestAnsiEventHandler) ED(param int) error {
h.recordCall("ED", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) EL(param int) error {
h.recordCall("EL", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) IL(param int) error {
h.recordCall("IL", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) DL(param int) error {
h.recordCall("DL", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) ICH(param int) error {
h.recordCall("ICH", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) DCH(param int) error {
h.recordCall("DCH", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) SGR(params []int) error {
strings := []string{}
for _, v := range params {
strings = append(strings, strconv.Itoa(v))
}
h.recordCall("SGR", strings)
return nil
}
func (h *TestAnsiEventHandler) SU(param int) error {
h.recordCall("SU", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) SD(param int) error {
h.recordCall("SD", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) DA(params []string) error {
h.recordCall("DA", params)
return nil
}
func (h *TestAnsiEventHandler) DECSTBM(top int, bottom int) error {
topS, bottomS := strconv.Itoa(top), strconv.Itoa(bottom)
h.recordCall("DECSTBM", []string{topS, bottomS})
return nil
}
func (h *TestAnsiEventHandler) RI() error {
h.recordCall("RI", nil)
return nil
}
func (h *TestAnsiEventHandler) IND() error {
h.recordCall("IND", nil)
return nil
}
func (h *TestAnsiEventHandler) Flush() error {
return nil
}

View File

@@ -1,7 +0,0 @@
Thank you for your contribution to Go-AutoRest! We will triage and review it as soon as we can.
As part of submitting, please make sure you can make the following assertions:
- [ ] I've tested my changes, adding unit tests if applicable.
- [ ] I've added Apache 2.0 Headers to the top of any new source files.
- [ ] I'm submitting this PR to the `dev` branch, except in the case of urgent bug fixes warranting their own release.
- [ ] If I'm targeting `master`, I've updated [CHANGELOG.md](https://github.com/Azure/go-autorest/blob/master/CHANGELOG.md) to address the changes I'm making.

View File

@@ -1,31 +0,0 @@
# The standard Go .gitignore file follows. (Sourced from: github.com/github/gitignore/master/Go.gitignore)
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
.DS_Store
.idea/
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
*.test
*.prof
# go-autorest specific
vendor/
autorest/azure/example/example

View File

@@ -1,24 +0,0 @@
sudo: false
language: go
go:
- 1.9
- 1.8
- 1.7
install:
- go get -u github.com/golang/lint/golint
- go get -u github.com/Masterminds/glide
- go get -u github.com/stretchr/testify
- go get -u github.com/GoASTScanner/gas
- glide install
script:
- grep -L -r --include *.go --exclude-dir vendor -P "Copyright (\d{4}|\(c\)) Microsoft" ./ | tee /dev/stderr | test -z "$(< /dev/stdin)"
- test -z "$(gofmt -s -l -w ./autorest/. | tee /dev/stderr)"
- test -z "$(golint ./autorest/... | tee /dev/stderr)"
- go vet ./autorest/...
- test -z "$(gas ./autorest/... | tee /dev/stderr | grep Error)"
- go build -v ./autorest/...
- go test -v ./autorest/...

View File

@@ -1,256 +0,0 @@
# CHANGELOG
## v9.4.1
### Bug Fixes
- Update the AccessTokensPath() to read access tokens path through AZURE_ACCESS_TOKEN_FILE. If this
environment variable is not set, it will fall back to use default path set by Azure CLI.
- Use case-insensitive string comparison for polling states.
## v9.4.0
### New Features
- Added WaitForCompletion() to Future as a default polling implementation.
### Bug Fixes
- Method Future.Done() shouldn't update polling status for unexpected HTTP status codes.
## v9.3.1
### Bug Fixes
- DoRetryForStatusCodes will retry if sender.Do returns a non-nil error.
## v9.3.0
### New Features
- Added PollingMethod() to Future so callers know what kind of polling mechanism is used.
- Added azure.ChangeToGet() which transforms an http.Request into a GET (to be used with LROs).
## v9.2.0
### New Features
- Added support for custom Azure Stack endpoints.
- Added type azure.Future used to track the status of long-running operations.
### Bug Fixes
- Preserve the original error in DoRetryWithRegistration when registration fails.
## v9.1.1
- Fixes a bug regarding the cookie jar on `autorest.Client.Sender`.
## v9.1.0
### New Features
- In cases where there is a non-empty error from the service, attempt to unmarshal it instead of uniformly calling it an "Unknown" error.
- Support for loading Azure CLI Authentication files.
- Automatically register your subscription with the Azure Resource Provider if it hadn't been previously.
### Bug Fixes
- RetriableRequest can now tolerate a ReadSeekable body being read but not reset.
- Adding missing Apache Headers
## v9.0.0
> **IMPORTANT:** This release was intially labeled incorrectly as `v8.4.0`. From the time it was released, it should have been marked `v9.0.0` because it contains breaking changes to the MSI packages. We appologize for any inconvenience this causes.
Adding MSI Endpoint Support and CLI token rehydration.
## v8.3.1
Pick up bug fix in adal for MSI support.
## v8.3.0
Updates to Error string formats for clarity. Also, adding a copy of the http.Response to errors for an improved debugging experience.
## v8.2.0
### New Features
- Add support for bearer authentication callbacks
- Support 429 response codes that include "Retry-After" header
- Support validation constraint "Pattern" for map keys
### Bug Fixes
- Make RetriableRequest work with multiple versions of Go
## v8.1.1
Updates the RetriableRequest to take advantage of GetBody() added in Go 1.8.
## v8.1.0
Adds RetriableRequest type for more efficient handling of retrying HTTP requests.
## v8.0.0
ADAL refactored into its own package.
Support for UNIX time.
## v7.3.1
- Version Testing now removed from production bits that are shipped with the library.
## v7.3.0
- Exposing new `RespondDecorator`, `ByDiscardingBody`. This allows operations
to acknowledge that they do not need either the entire or a trailing portion
of accepts response body. In doing so, Go's http library can reuse HTTP
connections more readily.
- Adding `PrepareDecorator` to target custom BaseURLs.
- Adding ACR suffix to public cloud environment.
- Updating Glide dependencies.
## v7.2.5
- Fixed the Active Directory endpoint for the China cloud.
- Removes UTF-8 BOM if present in response payload.
- Added telemetry.
## v7.2.3
- Fixing bug in calls to `DelayForBackoff` that caused doubling of delay
duration.
## v7.2.2
- autorest/azure: added ASM and ARM VM DNS suffixes.
## v7.2.1
- fixed parsing of UTC times that are not RFC3339 conformant.
## v7.2.0
- autorest/validation: Reformat validation error for better error message.
## v7.1.0
- preparer: Added support for multipart formdata - WithMultiPartFormdata()
- preparer: Added support for sending file in request body - WithFile
- client: Added RetryDuration parameter.
- autorest/validation: new package for validation code for Azure Go SDK.
## v7.0.7
- Add trailing / to endpoint
- azure: add EnvironmentFromName
## v7.0.6
- Add retry logic for 408, 500, 502, 503 and 504 status codes.
- Change url path and query encoding logic.
- Fix DelayForBackoff for proper exponential delay.
- Add CookieJar in Client.
## v7.0.5
- Add check to start polling only when status is in [200,201,202].
- Refactoring for unchecked errors.
- azure/persist changes.
- Fix 'file in use' issue in renewing token in deviceflow.
- Store header RetryAfter for subsequent requests in polling.
- Add attribute details in service error.
## v7.0.4
- Better error messages for long running operation failures
## v7.0.3
- Corrected DoPollForAsynchronous to properly handle the initial response
## v7.0.2
- Corrected DoPollForAsynchronous to continue using the polling method first discovered
## v7.0.1
- Fixed empty JSON input error in ByUnmarshallingJSON
- Fixed polling support for GET calls
- Changed format name from TimeRfc1123 to TimeRFC1123
## v7.0.0
- Added ByCopying responder with supporting TeeReadCloser
- Rewrote Azure asynchronous handling
- Reverted to only unmarshalling JSON
- Corrected handling of RFC3339 time strings and added support for Rfc1123 time format
The `json.Decoder` does not catch bad data as thoroughly as `json.Unmarshal`. Since
`encoding/json` successfully deserializes all core types, and extended types normally provide
their custom JSON serialization handlers, the code has been reverted back to using
`json.Unmarshal`. The original change to use `json.Decode` was made to reduce duplicate
code; there is no loss of function, and there is a gain in accuracy, by reverting.
Additionally, Azure services indicate requests to be polled by multiple means. The existing code
only checked for one of those (that is, the presence of the `Azure-AsyncOperation` header).
The new code correctly covers all cases and aligns with the other Azure SDKs.
## v6.1.0
- Introduced `date.ByUnmarshallingJSONDate` and `date.ByUnmarshallingJSONTime` to enable JSON encoded values.
## v6.0.0
- Completely reworked the handling of polled and asynchronous requests
- Removed unnecessary routines
- Reworked `mocks.Sender` to replay a series of `http.Response` objects
- Added `PrepareDecorators` for primitive types (e.g., bool, int32)
Handling polled and asynchronous requests is no longer part of `Client#Send`. Instead new
`SendDecorators` implement different styles of polled behavior. See`autorest.DoPollForStatusCodes`
and `azure.DoPollForAsynchronous` for examples.
## v5.0.0
- Added new RespondDecorators unmarshalling primitive types
- Corrected application of inspection and authorization PrependDecorators
## v4.0.0
- Added support for Azure long-running operations.
- Added cancelation support to all decorators and functions that may delay.
- Breaking: `DelayForBackoff` now accepts a channel, which may be nil.
## v3.1.0
- Add support for OAuth Device Flow authorization.
- Add support for ServicePrincipalTokens that are backed by an existing token, rather than other secret material.
- Add helpers for persisting and restoring Tokens.
- Increased code coverage in the github.com/Azure/autorest/azure package
## v3.0.0
- Breaking: `NewErrorWithError` no longer takes `statusCode int`.
- Breaking: `NewErrorWithStatusCode` is replaced with `NewErrorWithResponse`.
- Breaking: `Client#Send()` no longer takes `codes ...int` argument.
- Add: XML unmarshaling support with `ByUnmarshallingXML()`
- Stopped vending dependencies locally and switched to [Glide](https://github.com/Masterminds/glide).
Applications using this library should either use Glide or vendor dependencies locally some other way.
- Add: `azure.WithErrorUnlessStatusCode()` decorator to handle Azure errors.
- Fix: use `net/http.DefaultClient` as base client.
- Fix: Missing inspection for polling responses added.
- Add: CopyAndDecode helpers.
- Improved `./autorest/to` with `[]string` helpers.
- Removed golint suppressions in .travis.yml.
## v2.1.0
- Added `StatusCode` to `Error` for more easily obtaining the HTTP Reponse StatusCode (if any)
## v2.0.0
- Changed `to.StringMapPtr` method signature to return a pointer
- Changed `ServicePrincipalCertificateSecret` and `NewServicePrincipalTokenFromCertificate` to support generic certificate and private keys
## v1.0.0
- Added Logging inspectors to trace http.Request / Response
- Added support for User-Agent header
- Changed WithHeader PrepareDecorator to use set vs. add
- Added JSON to error when unmarshalling fails
- Added Client#Send method
- Corrected case of "Azure" in package paths
- Added "to" helpers, Azure helpers, and improved ease-of-use
- Corrected golint issues
## v1.0.1
- Added CHANGELOG.md
## v1.1.0
- Added mechanism to retrieve a ServicePrincipalToken using a certificate-signed JWT
- Added an example of creating a certificate-based ServicePrincipal and retrieving an OAuth token using the certificate
## v1.1.1
- Introduce godeps and vendor dependencies introduced in v1.1.1

View File

@@ -1,23 +0,0 @@
DIR?=./autorest/
default: build
build: fmt
go install $(DIR)
test:
go test $(DIR) || exit 1
vet:
@echo "go vet ."
@go vet $(DIR)... ; if [ $$? -eq 1 ]; then \
echo ""; \
echo "Vet found suspicious constructs. Please check the reported constructs"; \
echo "and fix them if necessary before submitting the code for review."; \
exit 1; \
fi
fmt:
gofmt -w $(DIR)
.PHONY: build test vet fmt

View File

@@ -1,132 +0,0 @@
# go-autorest
[![GoDoc](https://godoc.org/github.com/Azure/go-autorest/autorest?status.png)](https://godoc.org/github.com/Azure/go-autorest/autorest) [![Build Status](https://travis-ci.org/Azure/go-autorest.svg?branch=master)](https://travis-ci.org/Azure/go-autorest) [![Go Report Card](https://goreportcard.com/badge/Azure/go-autorest)](https://goreportcard.com/report/Azure/go-autorest)
## Usage
Package autorest implements an HTTP request pipeline suitable for use across multiple go-routines
and provides the shared routines relied on by AutoRest (see https://github.com/Azure/autorest/)
generated Go code.
The package breaks sending and responding to HTTP requests into three phases: Preparing, Sending,
and Responding. A typical pattern is:
```go
req, err := Prepare(&http.Request{},
token.WithAuthorization())
resp, err := Send(req,
WithLogging(logger),
DoErrorIfStatusCode(http.StatusInternalServerError),
DoCloseIfError(),
DoRetryForAttempts(5, time.Second))
err = Respond(resp,
ByDiscardingBody(),
ByClosing())
```
Each phase relies on decorators to modify and / or manage processing. Decorators may first modify
and then pass the data along, pass the data first and then modify the result, or wrap themselves
around passing the data (such as a logger might do). Decorators run in the order provided. For
example, the following:
```go
req, err := Prepare(&http.Request{},
WithBaseURL("https://microsoft.com/"),
WithPath("a"),
WithPath("b"),
WithPath("c"))
```
will set the URL to:
```
https://microsoft.com/a/b/c
```
Preparers and Responders may be shared and re-used (assuming the underlying decorators support
sharing and re-use). Performant use is obtained by creating one or more Preparers and Responders
shared among multiple go-routines, and a single Sender shared among multiple sending go-routines,
all bound together by means of input / output channels.
Decorators hold their passed state within a closure (such as the path components in the example
above). Be careful to share Preparers and Responders only in a context where such held state
applies. For example, it may not make sense to share a Preparer that applies a query string from a
fixed set of values. Similarly, sharing a Responder that reads the response body into a passed
struct (e.g., `ByUnmarshallingJson`) is likely incorrect.
Errors raised by autorest objects and methods will conform to the `autorest.Error` interface.
See the included examples for more detail. For details on the suggested use of this package by
generated clients, see the Client described below.
## Helpers
### Handling Swagger Dates
The Swagger specification (https://swagger.io) that drives AutoRest
(https://github.com/Azure/autorest/) precisely defines two date forms: date and date-time. The
github.com/Azure/go-autorest/autorest/date package provides time.Time derivations to ensure correct
parsing and formatting.
### Handling Empty Values
In JSON, missing values have different semantics than empty values. This is especially true for
services using the HTTP PATCH verb. The JSON submitted with a PATCH request generally contains
only those values to modify. Missing values are to be left unchanged. Developers, then, require a
means to both specify an empty value and to leave the value out of the submitted JSON.
The Go JSON package (`encoding/json`) supports the `omitempty` tag. When specified, it omits
empty values from the rendered JSON. Since Go defines default values for all base types (such as ""
for string and 0 for int) and provides no means to mark a value as actually empty, the JSON package
treats default values as meaning empty, omitting them from the rendered JSON. This means that, using
the Go base types encoded through the default JSON package, it is not possible to create JSON to
clear a value at the server.
The workaround within the Go community is to use pointers to base types in lieu of base types within
structures that map to JSON. For example, instead of a value of type `string`, the workaround uses
`*string`. While this enables distinguishing empty values from those to be unchanged, creating
pointers to a base type (notably constant, in-line values) requires additional variables. This, for
example,
```go
s := struct {
S *string
}{ S: &"foo" }
```
fails, while, this
```go
v := "foo"
s := struct {
S *string
}{ S: &v }
```
succeeds.
To ease using pointers, the subpackage `to` contains helpers that convert to and from pointers for
Go base types which have Swagger analogs. It also provides a helper that converts between
`map[string]string` and `map[string]*string`, enabling the JSON to specify that the value
associated with a key should be cleared. With the helpers, the previous example becomes
```go
s := struct {
S *string
}{ S: to.StringPtr("foo") }
```
## Install
```bash
go get github.com/Azure/go-autorest/autorest
go get github.com/Azure/go-autorest/autorest/azure
go get github.com/Azure/go-autorest/autorest/date
go get github.com/Azure/go-autorest/autorest/to
```
## License
See LICENSE file.
-----
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.

View File

@@ -1,253 +0,0 @@
# Azure Active Directory library for Go
This project provides a stand alone Azure Active Directory library for Go. The code was extracted
from [go-autorest](https://github.com/Azure/go-autorest/) project, which is used as a base for
[azure-sdk-for-go](https://github.com/Azure/azure-sdk-for-go).
## Installation
```
go get -u github.com/Azure/go-autorest/autorest/adal
```
## Usage
An Active Directory application is required in order to use this library. An application can be registered in the [Azure Portal](https://portal.azure.com/) follow these [guidelines](https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-integrating-applications) or using the [Azure CLI](https://github.com/Azure/azure-cli).
### Register an Azure AD Application with secret
1. Register a new application with a `secret` credential
```
az ad app create \
--display-name example-app \
--homepage https://example-app/home \
--identifier-uris https://example-app/app \
--password secret
```
2. Create a service principal using the `Application ID` from previous step
```
az ad sp create --id "Application ID"
```
* Replace `Application ID` with `appId` from step 1.
### Register an Azure AD Application with certificate
1. Create a private key
```
openssl genrsa -out "example-app.key" 2048
```
2. Create the certificate
```
openssl req -new -key "example-app.key" -subj "/CN=example-app" -out "example-app.csr"
openssl x509 -req -in "example-app.csr" -signkey "example-app.key" -out "example-app.crt" -days 10000
```
3. Create the PKCS12 version of the certificate containing also the private key
```
openssl pkcs12 -export -out "example-app.pfx" -inkey "example-app.key" -in "example-app.crt" -passout pass:
```
4. Register a new application with the certificate content form `example-app.crt`
```
certificateContents="$(tail -n+2 "example-app.crt" | head -n-1)"
az ad app create \
--display-name example-app \
--homepage https://example-app/home \
--identifier-uris https://example-app/app \
--key-usage Verify --end-date 2018-01-01 \
--key-value "${certificateContents}"
```
5. Create a service principal using the `Application ID` from previous step
```
az ad sp create --id "APPLICATION_ID"
```
* Replace `APPLICATION_ID` with `appId` from step 4.
### Grant the necessary permissions
Azure relies on a Role-Based Access Control (RBAC) model to manage the access to resources at a fine-grained
level. There is a set of [pre-defined roles](https://docs.microsoft.com/en-us/azure/active-directory/role-based-access-built-in-roles)
which can be assigned to a service principal of an Azure AD application depending of your needs.
```
az role assignment create --assigner "SERVICE_PRINCIPAL_ID" --role "ROLE_NAME"
```
* Replace the `SERVICE_PRINCIPAL_ID` with the `appId` from previous step.
* Replace the `ROLE_NAME` with a role name of your choice.
It is also possible to define custom role definitions.
```
az role definition create --role-definition role-definition.json
```
* Check [custom roles](https://docs.microsoft.com/en-us/azure/active-directory/role-based-access-control-custom-roles) for more details regarding the content of `role-definition.json` file.
### Acquire Access Token
The common configuration used by all flows:
```Go
const activeDirectoryEndpoint = "https://login.microsoftonline.com/"
tenantID := "TENANT_ID"
oauthConfig, err := adal.NewOAuthConfig(activeDirectoryEndpoint, tenantID)
applicationID := "APPLICATION_ID"
callback := func(token adal.Token) error {
// This is called after the token is acquired
}
// The resource for which the token is acquired
resource := "https://management.core.windows.net/"
```
* Replace the `TENANT_ID` with your tenant ID.
* Replace the `APPLICATION_ID` with the value from previous section.
#### Client Credentials
```Go
applicationSecret := "APPLICATION_SECRET"
spt, err := adal.NewServicePrincipalToken(
oauthConfig,
appliationID,
applicationSecret,
resource,
callbacks...)
if err != nil {
return nil, err
}
// Acquire a new access token
err = spt.Refresh()
if (err == nil) {
token := spt.Token
}
```
* Replace the `APPLICATION_SECRET` with the `password` value from previous section.
#### Client Certificate
```Go
certificatePath := "./example-app.pfx"
certData, err := ioutil.ReadFile(certificatePath)
if err != nil {
return nil, fmt.Errorf("failed to read the certificate file (%s): %v", certificatePath, err)
}
// Get the certificate and private key from pfx file
certificate, rsaPrivateKey, err := decodePkcs12(certData, "")
if err != nil {
return nil, fmt.Errorf("failed to decode pkcs12 certificate while creating spt: %v", err)
}
spt, err := adal.NewServicePrincipalTokenFromCertificate(
oauthConfig,
applicationID,
certificate,
rsaPrivateKey,
resource,
callbacks...)
// Acquire a new access token
err = spt.Refresh()
if (err == nil) {
token := spt.Token
}
```
* Update the certificate path to point to the example-app.pfx file which was created in previous section.
#### Device Code
```Go
oauthClient := &http.Client{}
// Acquire the device code
deviceCode, err := adal.InitiateDeviceAuth(
oauthClient,
oauthConfig,
applicationID,
resource)
if err != nil {
return nil, fmt.Errorf("Failed to start device auth flow: %s", err)
}
// Display the authentication message
fmt.Println(*deviceCode.Message)
// Wait here until the user is authenticated
token, err := adal.WaitForUserCompletion(oauthClient, deviceCode)
if err != nil {
return nil, fmt.Errorf("Failed to finish device auth flow: %s", err)
}
spt, err := adal.NewServicePrincipalTokenFromManualToken(
oauthConfig,
applicationID,
resource,
*token,
callbacks...)
if (err == nil) {
token := spt.Token
}
```
### Command Line Tool
A command line tool is available in `cmd/adal.go` that can acquire a token for a given resource. It supports all flows mentioned above.
```
adal -h
Usage of ./adal:
-applicationId string
application id
-certificatePath string
path to pk12/PFC application certificate
-mode string
authentication mode (device, secret, cert, refresh) (default "device")
-resource string
resource for which the token is requested
-secret string
application secret
-tenantId string
tenant id
-tokenCachePath string
location of oath token cache (default "/home/cgc/.adal/accessToken.json")
```
Example acquire a token for `https://management.core.windows.net/` using device code flow:
```
adal -mode device \
-applicationId "APPLICATION_ID" \
-tenantId "TENANT_ID" \
-resource https://management.core.windows.net/
```

View File

@@ -1,298 +0,0 @@
package main
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"flag"
"fmt"
"log"
"strings"
"crypto/rsa"
"crypto/x509"
"io/ioutil"
"net/http"
"os/user"
"github.com/Azure/go-autorest/autorest/adal"
"golang.org/x/crypto/pkcs12"
)
const (
deviceMode = "device"
clientSecretMode = "secret"
clientCertMode = "cert"
refreshMode = "refresh"
activeDirectoryEndpoint = "https://login.microsoftonline.com/"
)
type option struct {
name string
value string
}
var (
mode string
resource string
tenantID string
applicationID string
applicationSecret string
certificatePath string
tokenCachePath string
)
func checkMandatoryOptions(mode string, options ...option) {
for _, option := range options {
if strings.TrimSpace(option.value) == "" {
log.Fatalf("Authentication mode '%s' requires mandatory option '%s'.", mode, option.name)
}
}
}
func defaultTokenCachePath() string {
usr, err := user.Current()
if err != nil {
log.Fatal(err)
}
defaultTokenPath := usr.HomeDir + "/.adal/accessToken.json"
return defaultTokenPath
}
func init() {
flag.StringVar(&mode, "mode", "device", "authentication mode (device, secret, cert, refresh)")
flag.StringVar(&resource, "resource", "", "resource for which the token is requested")
flag.StringVar(&tenantID, "tenantId", "", "tenant id")
flag.StringVar(&applicationID, "applicationId", "", "application id")
flag.StringVar(&applicationSecret, "secret", "", "application secret")
flag.StringVar(&certificatePath, "certificatePath", "", "path to pk12/PFC application certificate")
flag.StringVar(&tokenCachePath, "tokenCachePath", defaultTokenCachePath(), "location of oath token cache")
flag.Parse()
switch mode = strings.TrimSpace(mode); mode {
case clientSecretMode:
checkMandatoryOptions(clientSecretMode,
option{name: "resource", value: resource},
option{name: "tenantId", value: tenantID},
option{name: "applicationId", value: applicationID},
option{name: "secret", value: applicationSecret},
)
case clientCertMode:
checkMandatoryOptions(clientCertMode,
option{name: "resource", value: resource},
option{name: "tenantId", value: tenantID},
option{name: "applicationId", value: applicationID},
option{name: "certificatePath", value: certificatePath},
)
case deviceMode:
checkMandatoryOptions(deviceMode,
option{name: "resource", value: resource},
option{name: "tenantId", value: tenantID},
option{name: "applicationId", value: applicationID},
)
case refreshMode:
checkMandatoryOptions(refreshMode,
option{name: "resource", value: resource},
option{name: "tenantId", value: tenantID},
option{name: "applicationId", value: applicationID},
)
default:
log.Fatalln("Authentication modes 'secret, 'cert', 'device' or 'refresh' are supported.")
}
}
func acquireTokenClientSecretFlow(oauthConfig adal.OAuthConfig,
appliationID string,
applicationSecret string,
resource string,
callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
spt, err := adal.NewServicePrincipalToken(
oauthConfig,
appliationID,
applicationSecret,
resource,
callbacks...)
if err != nil {
return nil, err
}
return spt, spt.Refresh()
}
func decodePkcs12(pkcs []byte, password string) (*x509.Certificate, *rsa.PrivateKey, error) {
privateKey, certificate, err := pkcs12.Decode(pkcs, password)
if err != nil {
return nil, nil, err
}
rsaPrivateKey, isRsaKey := privateKey.(*rsa.PrivateKey)
if !isRsaKey {
return nil, nil, fmt.Errorf("PKCS#12 certificate must contain an RSA private key")
}
return certificate, rsaPrivateKey, nil
}
func acquireTokenClientCertFlow(oauthConfig adal.OAuthConfig,
applicationID string,
applicationCertPath string,
resource string,
callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
certData, err := ioutil.ReadFile(certificatePath)
if err != nil {
return nil, fmt.Errorf("failed to read the certificate file (%s): %v", certificatePath, err)
}
certificate, rsaPrivateKey, err := decodePkcs12(certData, "")
if err != nil {
return nil, fmt.Errorf("failed to decode pkcs12 certificate while creating spt: %v", err)
}
spt, err := adal.NewServicePrincipalTokenFromCertificate(
oauthConfig,
applicationID,
certificate,
rsaPrivateKey,
resource,
callbacks...)
if err != nil {
return nil, err
}
return spt, spt.Refresh()
}
func acquireTokenDeviceCodeFlow(oauthConfig adal.OAuthConfig,
applicationID string,
resource string,
callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
oauthClient := &http.Client{}
deviceCode, err := adal.InitiateDeviceAuth(
oauthClient,
oauthConfig,
applicationID,
resource)
if err != nil {
return nil, fmt.Errorf("Failed to start device auth flow: %s", err)
}
fmt.Println(*deviceCode.Message)
token, err := adal.WaitForUserCompletion(oauthClient, deviceCode)
if err != nil {
return nil, fmt.Errorf("Failed to finish device auth flow: %s", err)
}
spt, err := adal.NewServicePrincipalTokenFromManualToken(
oauthConfig,
applicationID,
resource,
*token,
callbacks...)
return spt, err
}
func refreshToken(oauthConfig adal.OAuthConfig,
applicationID string,
resource string,
tokenCachePath string,
callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
token, err := adal.LoadToken(tokenCachePath)
if err != nil {
return nil, fmt.Errorf("failed to load token from cache: %v", err)
}
spt, err := adal.NewServicePrincipalTokenFromManualToken(
oauthConfig,
applicationID,
resource,
*token,
callbacks...)
if err != nil {
return nil, err
}
return spt, spt.Refresh()
}
func saveToken(spt adal.Token) error {
if tokenCachePath != "" {
err := adal.SaveToken(tokenCachePath, 0600, spt)
if err != nil {
return err
}
log.Printf("Acquired token was saved in '%s' file\n", tokenCachePath)
return nil
}
return fmt.Errorf("empty path for token cache")
}
func main() {
oauthConfig, err := adal.NewOAuthConfig(activeDirectoryEndpoint, tenantID)
if err != nil {
panic(err)
}
callback := func(token adal.Token) error {
return saveToken(token)
}
log.Printf("Authenticating with mode '%s'\n", mode)
switch mode {
case clientSecretMode:
_, err = acquireTokenClientSecretFlow(
*oauthConfig,
applicationID,
applicationSecret,
resource,
callback)
case clientCertMode:
_, err = acquireTokenClientCertFlow(
*oauthConfig,
applicationID,
certificatePath,
resource,
callback)
case deviceMode:
var spt *adal.ServicePrincipalToken
spt, err = acquireTokenDeviceCodeFlow(
*oauthConfig,
applicationID,
resource,
callback)
if err == nil {
err = saveToken(spt.Token)
}
case refreshMode:
_, err = refreshToken(
*oauthConfig,
applicationID,
resource,
tokenCachePath,
callback)
}
if err != nil {
log.Fatalf("Failed to acquire a token for resource %s. Error: %v", resource, err)
}
}

View File

@@ -32,8 +32,24 @@ type OAuthConfig struct {
DeviceCodeEndpoint url.URL
}
// IsZero returns true if the OAuthConfig object is zero-initialized.
func (oac OAuthConfig) IsZero() bool {
return oac == OAuthConfig{}
}
func validateStringParam(param, name string) error {
if len(param) == 0 {
return fmt.Errorf("parameter '" + name + "' cannot be empty")
}
return nil
}
// NewOAuthConfig returns an OAuthConfig with tenant specific urls
func NewOAuthConfig(activeDirectoryEndpoint, tenantID string) (*OAuthConfig, error) {
if err := validateStringParam(activeDirectoryEndpoint, "activeDirectoryEndpoint"); err != nil {
return nil, err
}
// it's legal for tenantID to be empty so don't validate it
const activeDirectoryEndpointTemplate = "%s/oauth2/%s?api-version=%s"
u, err := url.Parse(activeDirectoryEndpoint)
if err != nil {

View File

@@ -1,44 +0,0 @@
package adal
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"testing"
)
func TestNewOAuthConfig(t *testing.T) {
const testActiveDirectoryEndpoint = "https://login.test.com"
const testTenantID = "tenant-id-test"
config, err := NewOAuthConfig(testActiveDirectoryEndpoint, testTenantID)
if err != nil {
t.Fatalf("autorest/adal: Unexpected error while creating oauth configuration for tenant: %v.", err)
}
expected := "https://login.test.com/tenant-id-test/oauth2/authorize?api-version=1.0"
if config.AuthorizeEndpoint.String() != expected {
t.Fatalf("autorest/adal: Incorrect authorize url for Tenant from Environment. expected(%s). actual(%v).", expected, config.AuthorizeEndpoint)
}
expected = "https://login.test.com/tenant-id-test/oauth2/token?api-version=1.0"
if config.TokenEndpoint.String() != expected {
t.Fatalf("autorest/adal: Incorrect authorize url for Tenant from Environment. expected(%s). actual(%v).", expected, config.TokenEndpoint)
}
expected = "https://login.test.com/tenant-id-test/oauth2/devicecode?api-version=1.0"
if config.DeviceCodeEndpoint.String() != expected {
t.Fatalf("autorest/adal Incorrect devicecode url for Tenant from Environment. expected(%s). actual(%v).", expected, config.DeviceCodeEndpoint)
}
}

View File

@@ -1,330 +0,0 @@
package adal
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"encoding/json"
"fmt"
"net/http"
"strings"
"testing"
"github.com/Azure/go-autorest/autorest/mocks"
)
const (
TestResource = "SomeResource"
TestClientID = "SomeClientID"
TestTenantID = "SomeTenantID"
TestActiveDirectoryEndpoint = "https://login.test.com/"
)
var (
testOAuthConfig, _ = NewOAuthConfig(TestActiveDirectoryEndpoint, TestTenantID)
TestOAuthConfig = *testOAuthConfig
)
const MockDeviceCodeResponse = `
{
"device_code": "10000-40-1234567890",
"user_code": "ABCDEF",
"verification_url": "http://aka.ms/deviceauth",
"expires_in": "900",
"interval": "0"
}
`
const MockDeviceTokenResponse = `{
"access_token": "accessToken",
"refresh_token": "refreshToken",
"expires_in": "1000",
"expires_on": "2000",
"not_before": "3000",
"resource": "resource",
"token_type": "type"
}
`
func TestDeviceCodeIncludesResource(t *testing.T) {
sender := mocks.NewSender()
sender.AppendResponse(mocks.NewResponseWithContent(MockDeviceCodeResponse))
code, err := InitiateDeviceAuth(sender, TestOAuthConfig, TestClientID, TestResource)
if err != nil {
t.Fatalf("adal: unexpected error initiating device auth")
}
if code.Resource != TestResource {
t.Fatalf("adal: InitiateDeviceAuth failed to stash the resource in the DeviceCode struct")
}
}
func TestDeviceCodeReturnsErrorIfSendingFails(t *testing.T) {
sender := mocks.NewSender()
sender.SetError(fmt.Errorf("this is an error"))
_, err := InitiateDeviceAuth(sender, TestOAuthConfig, TestClientID, TestResource)
if err == nil || !strings.Contains(err.Error(), errCodeSendingFails) {
t.Fatalf("adal: failed to get correct error expected(%s) actual(%s)", errCodeSendingFails, err.Error())
}
}
func TestDeviceCodeReturnsErrorIfBadRequest(t *testing.T) {
sender := mocks.NewSender()
body := mocks.NewBody("doesn't matter")
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusBadRequest, "Bad Request"))
_, err := InitiateDeviceAuth(sender, TestOAuthConfig, TestClientID, TestResource)
if err == nil || !strings.Contains(err.Error(), errCodeHandlingFails) {
t.Fatalf("adal: failed to get correct error expected(%s) actual(%s)", errCodeHandlingFails, err.Error())
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
func TestDeviceCodeReturnsErrorIfCannotDeserializeDeviceCode(t *testing.T) {
gibberishJSON := strings.Replace(MockDeviceCodeResponse, "expires_in", "\":, :gibberish", -1)
sender := mocks.NewSender()
body := mocks.NewBody(gibberishJSON)
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK"))
_, err := InitiateDeviceAuth(sender, TestOAuthConfig, TestClientID, TestResource)
if err == nil || !strings.Contains(err.Error(), errCodeHandlingFails) {
t.Fatalf("adal: failed to get correct error expected(%s) actual(%s)", errCodeHandlingFails, err.Error())
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
func TestDeviceCodeReturnsErrorIfEmptyDeviceCode(t *testing.T) {
sender := mocks.NewSender()
body := mocks.NewBody("")
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK"))
_, err := InitiateDeviceAuth(sender, TestOAuthConfig, TestClientID, TestResource)
if err != ErrDeviceCodeEmpty {
t.Fatalf("adal: failed to get correct error expected(%s) actual(%s)", ErrDeviceCodeEmpty, err.Error())
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
func deviceCode() *DeviceCode {
var deviceCode DeviceCode
_ = json.Unmarshal([]byte(MockDeviceCodeResponse), &deviceCode)
deviceCode.Resource = TestResource
deviceCode.ClientID = TestClientID
return &deviceCode
}
func TestDeviceTokenReturns(t *testing.T) {
sender := mocks.NewSender()
body := mocks.NewBody(MockDeviceTokenResponse)
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK"))
_, err := WaitForUserCompletion(sender, deviceCode())
if err != nil {
t.Fatalf("adal: got error unexpectedly")
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
func TestDeviceTokenReturnsErrorIfSendingFails(t *testing.T) {
sender := mocks.NewSender()
sender.SetError(fmt.Errorf("this is an error"))
_, err := WaitForUserCompletion(sender, deviceCode())
if err == nil || !strings.Contains(err.Error(), errTokenSendingFails) {
t.Fatalf("adal: failed to get correct error expected(%s) actual(%s)", errTokenSendingFails, err.Error())
}
}
func TestDeviceTokenReturnsErrorIfServerError(t *testing.T) {
sender := mocks.NewSender()
body := mocks.NewBody("")
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusInternalServerError, "Internal Server Error"))
_, err := WaitForUserCompletion(sender, deviceCode())
if err == nil || !strings.Contains(err.Error(), errTokenHandlingFails) {
t.Fatalf("adal: failed to get correct error expected(%s) actual(%s)", errTokenHandlingFails, err.Error())
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
func TestDeviceTokenReturnsErrorIfCannotDeserializeDeviceToken(t *testing.T) {
gibberishJSON := strings.Replace(MockDeviceTokenResponse, "expires_in", ";:\"gibberish", -1)
sender := mocks.NewSender()
body := mocks.NewBody(gibberishJSON)
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK"))
_, err := WaitForUserCompletion(sender, deviceCode())
if err == nil || !strings.Contains(err.Error(), errTokenHandlingFails) {
t.Fatalf("adal: failed to get correct error expected(%s) actual(%s)", errTokenHandlingFails, err.Error())
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
func errorDeviceTokenResponse(message string) string {
return `{ "error": "` + message + `" }`
}
func TestDeviceTokenReturnsErrorIfAuthorizationPending(t *testing.T) {
sender := mocks.NewSender()
body := mocks.NewBody(errorDeviceTokenResponse("authorization_pending"))
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusBadRequest, "Bad Request"))
_, err := CheckForUserCompletion(sender, deviceCode())
if err != ErrDeviceAuthorizationPending {
t.Fatalf("!!!")
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
func TestDeviceTokenReturnsErrorIfSlowDown(t *testing.T) {
sender := mocks.NewSender()
body := mocks.NewBody(errorDeviceTokenResponse("slow_down"))
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusBadRequest, "Bad Request"))
_, err := CheckForUserCompletion(sender, deviceCode())
if err != ErrDeviceSlowDown {
t.Fatalf("!!!")
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
type deviceTokenSender struct {
errorString string
attempts int
}
func newDeviceTokenSender(deviceErrorString string) *deviceTokenSender {
return &deviceTokenSender{errorString: deviceErrorString, attempts: 0}
}
func (s *deviceTokenSender) Do(req *http.Request) (*http.Response, error) {
var resp *http.Response
if s.attempts < 1 {
s.attempts++
resp = mocks.NewResponseWithContent(errorDeviceTokenResponse(s.errorString))
} else {
resp = mocks.NewResponseWithContent(MockDeviceTokenResponse)
}
return resp, nil
}
// since the above only exercise CheckForUserCompletion, we repeat the test here,
// but with the intent of showing that WaitForUserCompletion loops properly.
func TestDeviceTokenSucceedsWithIntermediateAuthPending(t *testing.T) {
sender := newDeviceTokenSender("authorization_pending")
_, err := WaitForUserCompletion(sender, deviceCode())
if err != nil {
t.Fatalf("unexpected error occurred")
}
}
// same as above but with SlowDown now
func TestDeviceTokenSucceedsWithIntermediateSlowDown(t *testing.T) {
sender := newDeviceTokenSender("slow_down")
_, err := WaitForUserCompletion(sender, deviceCode())
if err != nil {
t.Fatalf("unexpected error occurred")
}
}
func TestDeviceTokenReturnsErrorIfAccessDenied(t *testing.T) {
sender := mocks.NewSender()
body := mocks.NewBody(errorDeviceTokenResponse("access_denied"))
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusBadRequest, "Bad Request"))
_, err := WaitForUserCompletion(sender, deviceCode())
if err != ErrDeviceAccessDenied {
t.Fatalf("adal: got wrong error expected(%s) actual(%s)", ErrDeviceAccessDenied.Error(), err.Error())
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
func TestDeviceTokenReturnsErrorIfCodeExpired(t *testing.T) {
sender := mocks.NewSender()
body := mocks.NewBody(errorDeviceTokenResponse("code_expired"))
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusBadRequest, "Bad Request"))
_, err := WaitForUserCompletion(sender, deviceCode())
if err != ErrDeviceCodeExpired {
t.Fatalf("adal: got wrong error expected(%s) actual(%s)", ErrDeviceCodeExpired.Error(), err.Error())
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
func TestDeviceTokenReturnsErrorForUnknownError(t *testing.T) {
sender := mocks.NewSender()
body := mocks.NewBody(errorDeviceTokenResponse("unknown_error"))
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusBadRequest, "Bad Request"))
_, err := WaitForUserCompletion(sender, deviceCode())
if err == nil {
t.Fatalf("failed to get error")
}
if err != ErrDeviceGeneric {
t.Fatalf("adal: got wrong error expected(%s) actual(%s)", ErrDeviceGeneric.Error(), err.Error())
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}
func TestDeviceTokenReturnsErrorIfTokenEmptyAndStatusOK(t *testing.T) {
sender := mocks.NewSender()
body := mocks.NewBody("")
sender.AppendResponse(mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK"))
_, err := WaitForUserCompletion(sender, deviceCode())
if err != ErrOAuthTokenEmpty {
t.Fatalf("adal: got wrong error expected(%s) actual(%s)", ErrOAuthTokenEmpty.Error(), err.Error())
}
if body.IsOpen() {
t.Fatalf("response body was left open!")
}
}

View File

@@ -1,20 +0,0 @@
// +build !windows
package adal
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// msiPath is the path to the MSI Extension settings file (to discover the endpoint)
var msiPath = "/var/lib/waagent/ManagedIdentity-Settings"

View File

@@ -1,25 +0,0 @@
// +build windows
package adal
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"os"
"strings"
)
// msiPath is the path to the MSI Extension settings file (to discover the endpoint)
var msiPath = strings.Join([]string{os.Getenv("SystemDrive"), "WindowsAzure/Config/ManagedIdentity-Settings"}, "/")

View File

@@ -1,171 +0,0 @@
package adal
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"encoding/json"
"io/ioutil"
"os"
"path"
"reflect"
"runtime"
"strings"
"testing"
)
const MockTokenJSON string = `{
"access_token": "accessToken",
"refresh_token": "refreshToken",
"expires_in": "1000",
"expires_on": "2000",
"not_before": "3000",
"resource": "resource",
"token_type": "type"
}`
var TestToken = Token{
AccessToken: "accessToken",
RefreshToken: "refreshToken",
ExpiresIn: "1000",
ExpiresOn: "2000",
NotBefore: "3000",
Resource: "resource",
Type: "type",
}
func writeTestTokenFile(t *testing.T, suffix string, contents string) *os.File {
f, err := ioutil.TempFile(os.TempDir(), suffix)
if err != nil {
t.Fatalf("azure: unexpected error when creating temp file: %v", err)
}
defer f.Close()
_, err = f.Write([]byte(contents))
if err != nil {
t.Fatalf("azure: unexpected error when writing temp test file: %v", err)
}
return f
}
func TestLoadToken(t *testing.T) {
f := writeTestTokenFile(t, "testloadtoken", MockTokenJSON)
defer os.Remove(f.Name())
expectedToken := TestToken
actualToken, err := LoadToken(f.Name())
if err != nil {
t.Fatalf("azure: unexpected error loading token from file: %v", err)
}
if *actualToken != expectedToken {
t.Fatalf("azure: failed to decode properly expected(%v) actual(%v)", expectedToken, *actualToken)
}
// test that LoadToken closes the file properly
err = SaveToken(f.Name(), 0600, *actualToken)
if err != nil {
t.Fatalf("azure: could not save token after LoadToken: %v", err)
}
}
func TestLoadTokenFailsBadPath(t *testing.T) {
_, err := LoadToken("/tmp/this_file_should_never_exist_really")
expectedSubstring := "failed to open file"
if err == nil || !strings.Contains(err.Error(), expectedSubstring) {
t.Fatalf("azure: failed to get correct error expected(%s) actual(%s)", expectedSubstring, err.Error())
}
}
func TestLoadTokenFailsBadJson(t *testing.T) {
gibberishJSON := strings.Replace(MockTokenJSON, "expires_on", ";:\"gibberish", -1)
f := writeTestTokenFile(t, "testloadtokenfailsbadjson", gibberishJSON)
defer os.Remove(f.Name())
_, err := LoadToken(f.Name())
expectedSubstring := "failed to decode contents of file"
if err == nil || !strings.Contains(err.Error(), expectedSubstring) {
t.Fatalf("azure: failed to get correct error expected(%s) actual(%s)", expectedSubstring, err.Error())
}
}
func token() *Token {
var token Token
json.Unmarshal([]byte(MockTokenJSON), &token)
return &token
}
func TestSaveToken(t *testing.T) {
f, err := ioutil.TempFile("", "testloadtoken")
if err != nil {
t.Fatalf("azure: unexpected error when creating temp file: %v", err)
}
defer os.Remove(f.Name())
f.Close()
mode := os.ModePerm & 0642
err = SaveToken(f.Name(), mode, *token())
if err != nil {
t.Fatalf("azure: unexpected error saving token to file: %v", err)
}
fi, err := os.Stat(f.Name()) // open a new stat as held ones are not fresh
if err != nil {
t.Fatalf("azure: stat failed: %v", err)
}
if runtime.GOOS != "windows" { // permissions don't work on Windows
if perm := fi.Mode().Perm(); perm != mode {
t.Fatalf("azure: wrong file perm. got:%s; expected:%s file :%s", perm, mode, f.Name())
}
}
var actualToken Token
var expectedToken Token
json.Unmarshal([]byte(MockTokenJSON), expectedToken)
contents, err := ioutil.ReadFile(f.Name())
if err != nil {
t.Fatal("!!")
}
json.Unmarshal(contents, actualToken)
if !reflect.DeepEqual(actualToken, expectedToken) {
t.Fatal("azure: token was not serialized correctly")
}
}
func TestSaveTokenFailsNoPermission(t *testing.T) {
pathWhereWeShouldntHavePermission := "/usr/thiswontwork/atall"
if runtime.GOOS == "windows" {
pathWhereWeShouldntHavePermission = path.Join(os.Getenv("windir"), "system32\\mytokendir\\mytoken")
}
err := SaveToken(pathWhereWeShouldntHavePermission, 0644, *token())
expectedSubstring := "failed to create directory"
if err == nil || !strings.Contains(err.Error(), expectedSubstring) {
t.Fatalf("azure: failed to get correct error expected(%s) actual(%v)", expectedSubstring, err)
}
}
func TestSaveTokenFailsCantCreate(t *testing.T) {
tokenPath := "/thiswontwork"
if runtime.GOOS == "windows" {
tokenPath = path.Join(os.Getenv("windir"), "system32")
}
err := SaveToken(tokenPath, 0644, *token())
expectedSubstring := "failed to create the temp file to write the token"
if err == nil || !strings.Contains(err.Error(), expectedSubstring) {
t.Fatalf("azure: failed to get correct error expected(%s) actual(%v)", expectedSubstring, err)
}
}

View File

@@ -15,6 +15,7 @@ package adal
// limitations under the License.
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/sha1"
@@ -23,10 +24,12 @@ import (
"encoding/json"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/Azure/go-autorest/autorest/date"
@@ -42,11 +45,20 @@ const (
// OAuthGrantTypeClientCredentials is the "grant_type" identifier used in credential flows
OAuthGrantTypeClientCredentials = "client_credentials"
// OAuthGrantTypeUserPass is the "grant_type" identifier used in username and password auth flows
OAuthGrantTypeUserPass = "password"
// OAuthGrantTypeRefreshToken is the "grant_type" identifier used in refresh token flows
OAuthGrantTypeRefreshToken = "refresh_token"
// OAuthGrantTypeAuthorizationCode is the "grant_type" identifier used in authorization code flows
OAuthGrantTypeAuthorizationCode = "authorization_code"
// metadataHeader is the header required by MSI extension
metadataHeader = "Metadata"
// msiEndpoint is the well known endpoint for getting MSI authentications tokens
msiEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"
)
// OAuthTokenProvider is an interface which should be implemented by an access token retriever
@@ -54,6 +66,12 @@ type OAuthTokenProvider interface {
OAuthToken() string
}
// TokenRefreshError is an interface used by errors returned during token refresh.
type TokenRefreshError interface {
error
Response() *http.Response
}
// Refresher is an interface for token refresh functionality
type Refresher interface {
Refresh() error
@@ -61,6 +79,13 @@ type Refresher interface {
EnsureFresh() error
}
// RefresherWithContext is an interface for token refresh functionality
type RefresherWithContext interface {
RefreshWithContext(ctx context.Context) error
RefreshExchangeWithContext(ctx context.Context, resource string) error
EnsureFreshWithContext(ctx context.Context) error
}
// TokenRefreshCallback is the type representing callbacks that will be called after
// a successful token refresh
type TokenRefreshCallback func(Token) error
@@ -78,6 +103,11 @@ type Token struct {
Type string `json:"token_type"`
}
// IsZero returns true if the token object is zero-initialized.
func (t Token) IsZero() bool {
return t == Token{}
}
// Expires returns the time.Time when the Token expires.
func (t Token) Expires() time.Time {
s, err := strconv.Atoi(t.ExpiresOn)
@@ -145,6 +175,34 @@ type ServicePrincipalCertificateSecret struct {
type ServicePrincipalMSISecret struct {
}
// ServicePrincipalUsernamePasswordSecret implements ServicePrincipalSecret for username and password auth.
type ServicePrincipalUsernamePasswordSecret struct {
Username string
Password string
}
// ServicePrincipalAuthorizationCodeSecret implements ServicePrincipalSecret for authorization code auth.
type ServicePrincipalAuthorizationCodeSecret struct {
ClientSecret string
AuthorizationCode string
RedirectURI string
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
func (secret *ServicePrincipalAuthorizationCodeSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
v.Set("code", secret.AuthorizationCode)
v.Set("client_secret", secret.ClientSecret)
v.Set("redirect_uri", secret.RedirectURI)
return nil
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
func (secret *ServicePrincipalUsernamePasswordSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
v.Set("username", secret.Username)
v.Set("password", secret.Password)
return nil
}
// SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
return nil
@@ -197,27 +255,47 @@ func (secret *ServicePrincipalCertificateSecret) SetAuthenticationValues(spt *Se
// ServicePrincipalToken encapsulates a Token created for a Service Principal.
type ServicePrincipalToken struct {
Token
token Token
secret ServicePrincipalSecret
oauthConfig OAuthConfig
clientID string
resource string
autoRefresh bool
refreshLock *sync.RWMutex
refreshWithin time.Duration
sender Sender
refreshCallbacks []TokenRefreshCallback
}
func validateOAuthConfig(oac OAuthConfig) error {
if oac.IsZero() {
return fmt.Errorf("parameter 'oauthConfig' cannot be zero-initialized")
}
return nil
}
// NewServicePrincipalTokenWithSecret create a ServicePrincipalToken using the supplied ServicePrincipalSecret implementation.
func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, resource string, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateOAuthConfig(oauthConfig); err != nil {
return nil, err
}
if err := validateStringParam(id, "id"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
if secret == nil {
return nil, fmt.Errorf("parameter 'secret' cannot be nil")
}
spt := &ServicePrincipalToken{
oauthConfig: oauthConfig,
secret: secret,
clientID: id,
resource: resource,
autoRefresh: true,
refreshLock: &sync.RWMutex{},
refreshWithin: defaultRefresh,
sender: &http.Client{},
refreshCallbacks: callbacks,
@@ -227,6 +305,18 @@ func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, reso
// NewServicePrincipalTokenFromManualToken creates a ServicePrincipalToken using the supplied token
func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID string, resource string, token Token, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateOAuthConfig(oauthConfig); err != nil {
return nil, err
}
if err := validateStringParam(clientID, "clientID"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
if token.IsZero() {
return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized")
}
spt, err := NewServicePrincipalTokenWithSecret(
oauthConfig,
clientID,
@@ -237,7 +327,7 @@ func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID s
return nil, err
}
spt.Token = token
spt.token = token
return spt, nil
}
@@ -245,6 +335,18 @@ func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID s
// NewServicePrincipalToken creates a ServicePrincipalToken from the supplied Service Principal
// credentials scoped to the named resource.
func NewServicePrincipalToken(oauthConfig OAuthConfig, clientID string, secret string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateOAuthConfig(oauthConfig); err != nil {
return nil, err
}
if err := validateStringParam(clientID, "clientID"); err != nil {
return nil, err
}
if err := validateStringParam(secret, "secret"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
return NewServicePrincipalTokenWithSecret(
oauthConfig,
clientID,
@@ -256,8 +358,23 @@ func NewServicePrincipalToken(oauthConfig OAuthConfig, clientID string, secret s
)
}
// NewServicePrincipalTokenFromCertificate create a ServicePrincipalToken from the supplied pkcs12 bytes.
// NewServicePrincipalTokenFromCertificate creates a ServicePrincipalToken from the supplied pkcs12 bytes.
func NewServicePrincipalTokenFromCertificate(oauthConfig OAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateOAuthConfig(oauthConfig); err != nil {
return nil, err
}
if err := validateStringParam(clientID, "clientID"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
if certificate == nil {
return nil, fmt.Errorf("parameter 'certificate' cannot be nil")
}
if privateKey == nil {
return nil, fmt.Errorf("parameter 'privateKey' cannot be nil")
}
return NewServicePrincipalTokenWithSecret(
oauthConfig,
clientID,
@@ -270,59 +387,169 @@ func NewServicePrincipalTokenFromCertificate(oauthConfig OAuthConfig, clientID s
)
}
// GetMSIVMEndpoint gets the MSI endpoint on Virtual Machines.
func GetMSIVMEndpoint() (string, error) {
return getMSIVMEndpoint(msiPath)
// NewServicePrincipalTokenFromUsernamePassword creates a ServicePrincipalToken from the username and password.
func NewServicePrincipalTokenFromUsernamePassword(oauthConfig OAuthConfig, clientID string, username string, password string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateOAuthConfig(oauthConfig); err != nil {
return nil, err
}
if err := validateStringParam(clientID, "clientID"); err != nil {
return nil, err
}
if err := validateStringParam(username, "username"); err != nil {
return nil, err
}
if err := validateStringParam(password, "password"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
return NewServicePrincipalTokenWithSecret(
oauthConfig,
clientID,
resource,
&ServicePrincipalUsernamePasswordSecret{
Username: username,
Password: password,
},
callbacks...,
)
}
func getMSIVMEndpoint(path string) (string, error) {
// Read MSI settings
bytes, err := ioutil.ReadFile(path)
if err != nil {
return "", err
// NewServicePrincipalTokenFromAuthorizationCode creates a ServicePrincipalToken from the
func NewServicePrincipalTokenFromAuthorizationCode(oauthConfig OAuthConfig, clientID string, clientSecret string, authorizationCode string, redirectURI string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateOAuthConfig(oauthConfig); err != nil {
return nil, err
}
msiSettings := struct {
URL string `json:"url"`
}{}
err = json.Unmarshal(bytes, &msiSettings)
if err != nil {
return "", err
if err := validateStringParam(clientID, "clientID"); err != nil {
return nil, err
}
if err := validateStringParam(clientSecret, "clientSecret"); err != nil {
return nil, err
}
if err := validateStringParam(authorizationCode, "authorizationCode"); err != nil {
return nil, err
}
if err := validateStringParam(redirectURI, "redirectURI"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
return msiSettings.URL, nil
return NewServicePrincipalTokenWithSecret(
oauthConfig,
clientID,
resource,
&ServicePrincipalAuthorizationCodeSecret{
ClientSecret: clientSecret,
AuthorizationCode: authorizationCode,
RedirectURI: redirectURI,
},
callbacks...,
)
}
// GetMSIVMEndpoint gets the MSI endpoint on Virtual Machines.
func GetMSIVMEndpoint() (string, error) {
return msiEndpoint, nil
}
// NewServicePrincipalTokenFromMSI creates a ServicePrincipalToken via the MSI VM Extension.
// It will use the system assigned identity when creating the token.
func NewServicePrincipalTokenFromMSI(msiEndpoint, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
return newServicePrincipalTokenFromMSI(msiEndpoint, resource, nil, callbacks...)
}
// NewServicePrincipalTokenFromMSIWithUserAssignedID creates a ServicePrincipalToken via the MSI VM Extension.
// It will use the specified user assigned identity when creating the token.
func NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource string, userAssignedID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
return newServicePrincipalTokenFromMSI(msiEndpoint, resource, &userAssignedID, callbacks...)
}
func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedID *string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateStringParam(msiEndpoint, "msiEndpoint"); err != nil {
return nil, err
}
if err := validateStringParam(resource, "resource"); err != nil {
return nil, err
}
if userAssignedID != nil {
if err := validateStringParam(*userAssignedID, "userAssignedID"); err != nil {
return nil, err
}
}
// We set the oauth config token endpoint to be MSI's endpoint
msiEndpointURL, err := url.Parse(msiEndpoint)
if err != nil {
return nil, err
}
oauthConfig, err := NewOAuthConfig(msiEndpointURL.String(), "")
if err != nil {
return nil, err
v := url.Values{}
v.Set("resource", resource)
v.Set("api-version", "2018-02-01")
if userAssignedID != nil {
v.Set("client_id", *userAssignedID)
}
msiEndpointURL.RawQuery = v.Encode()
spt := &ServicePrincipalToken{
oauthConfig: *oauthConfig,
oauthConfig: OAuthConfig{
TokenEndpoint: *msiEndpointURL,
},
secret: &ServicePrincipalMSISecret{},
resource: resource,
autoRefresh: true,
refreshLock: &sync.RWMutex{},
refreshWithin: defaultRefresh,
sender: &http.Client{},
refreshCallbacks: callbacks,
}
if userAssignedID != nil {
spt.clientID = *userAssignedID
}
return spt, nil
}
// internal type that implements TokenRefreshError
type tokenRefreshError struct {
message string
resp *http.Response
}
// Error implements the error interface which is part of the TokenRefreshError interface.
func (tre tokenRefreshError) Error() string {
return tre.message
}
// Response implements the TokenRefreshError interface, it returns the raw HTTP response from the refresh operation.
func (tre tokenRefreshError) Response() *http.Response {
return tre.resp
}
func newTokenRefreshError(message string, resp *http.Response) TokenRefreshError {
return tokenRefreshError{message: message, resp: resp}
}
// EnsureFresh will refresh the token if it will expire within the refresh window (as set by
// RefreshWithin) and autoRefresh flag is on.
// RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use.
func (spt *ServicePrincipalToken) EnsureFresh() error {
if spt.autoRefresh && spt.WillExpireIn(spt.refreshWithin) {
return spt.Refresh()
return spt.EnsureFreshWithContext(context.Background())
}
// EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by
// RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use.
func (spt *ServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error {
if spt.autoRefresh && spt.token.WillExpireIn(spt.refreshWithin) {
// take the write lock then check to see if the token was already refreshed
spt.refreshLock.Lock()
defer spt.refreshLock.Unlock()
if spt.token.WillExpireIn(spt.refreshWithin) {
return spt.refreshInternal(ctx, spt.resource)
}
}
return nil
}
@@ -331,7 +558,7 @@ func (spt *ServicePrincipalToken) EnsureFresh() error {
func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error {
if spt.refreshCallbacks != nil {
for _, callback := range spt.refreshCallbacks {
err := callback(spt.Token)
err := callback(spt.token)
if err != nil {
return fmt.Errorf("adal: TokenRefreshCallback handler failed. Error = '%v'", err)
}
@@ -341,46 +568,94 @@ func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error {
}
// Refresh obtains a fresh token for the Service Principal.
// This method is not safe for concurrent use and should be syncrhonized.
func (spt *ServicePrincipalToken) Refresh() error {
return spt.refreshInternal(spt.resource)
return spt.RefreshWithContext(context.Background())
}
// RefreshWithContext obtains a fresh token for the Service Principal.
// This method is not safe for concurrent use and should be syncrhonized.
func (spt *ServicePrincipalToken) RefreshWithContext(ctx context.Context) error {
spt.refreshLock.Lock()
defer spt.refreshLock.Unlock()
return spt.refreshInternal(ctx, spt.resource)
}
// RefreshExchange refreshes the token, but for a different resource.
// This method is not safe for concurrent use and should be syncrhonized.
func (spt *ServicePrincipalToken) RefreshExchange(resource string) error {
return spt.refreshInternal(resource)
return spt.RefreshExchangeWithContext(context.Background(), resource)
}
func (spt *ServicePrincipalToken) refreshInternal(resource string) error {
v := url.Values{}
v.Set("client_id", spt.clientID)
v.Set("resource", resource)
// RefreshExchangeWithContext refreshes the token, but for a different resource.
// This method is not safe for concurrent use and should be syncrhonized.
func (spt *ServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error {
spt.refreshLock.Lock()
defer spt.refreshLock.Unlock()
return spt.refreshInternal(ctx, resource)
}
if spt.RefreshToken != "" {
v.Set("grant_type", OAuthGrantTypeRefreshToken)
v.Set("refresh_token", spt.RefreshToken)
} else {
v.Set("grant_type", OAuthGrantTypeClientCredentials)
err := spt.secret.SetAuthenticationValues(spt, &v)
if err != nil {
return err
}
func (spt *ServicePrincipalToken) getGrantType() string {
switch spt.secret.(type) {
case *ServicePrincipalUsernamePasswordSecret:
return OAuthGrantTypeUserPass
case *ServicePrincipalAuthorizationCodeSecret:
return OAuthGrantTypeAuthorizationCode
default:
return OAuthGrantTypeClientCredentials
}
}
s := v.Encode()
body := ioutil.NopCloser(strings.NewReader(s))
req, err := http.NewRequest(http.MethodPost, spt.oauthConfig.TokenEndpoint.String(), body)
func isIMDS(u url.URL) bool {
imds, err := url.Parse(msiEndpoint)
if err != nil {
return false
}
return u.Host == imds.Host && u.Path == imds.Path
}
func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error {
req, err := http.NewRequest(http.MethodPost, spt.oauthConfig.TokenEndpoint.String(), nil)
if err != nil {
return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err)
}
req = req.WithContext(ctx)
if !isIMDS(spt.oauthConfig.TokenEndpoint) {
v := url.Values{}
v.Set("client_id", spt.clientID)
v.Set("resource", resource)
if spt.token.RefreshToken != "" {
v.Set("grant_type", OAuthGrantTypeRefreshToken)
v.Set("refresh_token", spt.token.RefreshToken)
} else {
v.Set("grant_type", spt.getGrantType())
err := spt.secret.SetAuthenticationValues(spt, &v)
if err != nil {
return err
}
}
s := v.Encode()
body := ioutil.NopCloser(strings.NewReader(s))
req.ContentLength = int64(len(s))
req.Header.Set(contentType, mimeTypeFormPost)
req.Body = body
}
req.ContentLength = int64(len(s))
req.Header.Set(contentType, mimeTypeFormPost)
if _, ok := spt.secret.(*ServicePrincipalMSISecret); ok {
req.Method = http.MethodGet
req.Header.Set(metadataHeader, "true")
}
resp, err := spt.sender.Do(req)
var resp *http.Response
if isIMDS(spt.oauthConfig.TokenEndpoint) {
resp, err = retry(spt.sender, req)
} else {
resp, err = spt.sender.Do(req)
}
if err != nil {
return fmt.Errorf("adal: Failed to execute the refresh request. Error = '%v'", err)
return newTokenRefreshError(fmt.Sprintf("adal: Failed to execute the refresh request. Error = '%v'", err), nil)
}
defer resp.Body.Close()
@@ -388,11 +663,15 @@ func (spt *ServicePrincipalToken) refreshInternal(resource string) error {
if resp.StatusCode != http.StatusOK {
if err != nil {
return fmt.Errorf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body", resp.StatusCode)
return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body: %v", resp.StatusCode, err), resp)
}
return fmt.Errorf("adal: Refresh request failed. Status Code = '%d'. Response body: %s", resp.StatusCode, string(rb))
return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Response body: %s", resp.StatusCode, string(rb)), resp)
}
// for the following error cases don't return a TokenRefreshError. the operation succeeded
// but some transient failure happened during deserialization. by returning a generic error
// the retry logic will kick in (we don't retry on TokenRefreshError).
if err != nil {
return fmt.Errorf("adal: Failed to read a new service principal token during refresh. Error = '%v'", err)
}
@@ -405,11 +684,86 @@ func (spt *ServicePrincipalToken) refreshInternal(resource string) error {
return fmt.Errorf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb))
}
spt.Token = token
spt.token = token
return spt.InvokeRefreshCallbacks(token)
}
func retry(sender Sender, req *http.Request) (resp *http.Response, err error) {
retries := []int{
http.StatusRequestTimeout, // 408
http.StatusTooManyRequests, // 429
http.StatusInternalServerError, // 500
http.StatusBadGateway, // 502
http.StatusServiceUnavailable, // 503
http.StatusGatewayTimeout, // 504
}
// Extra retry status codes requered
retries = append(retries, http.StatusNotFound,
// all remaining 5xx
http.StatusNotImplemented,
http.StatusHTTPVersionNotSupported,
http.StatusVariantAlsoNegotiates,
http.StatusInsufficientStorage,
http.StatusLoopDetected,
http.StatusNotExtended,
http.StatusNetworkAuthenticationRequired)
attempt := 0
maxAttempts := 5
for attempt < maxAttempts {
resp, err = sender.Do(req)
// retry on temporary network errors, e.g. transient network failures.
if (err != nil && !isTemporaryNetworkError(err)) || resp.StatusCode == http.StatusOK || !containsInt(retries, resp.StatusCode) {
return
}
if !delay(resp, req.Context().Done()) {
select {
case <-time.After(time.Second):
attempt++
case <-req.Context().Done():
err = req.Context().Err()
return
}
}
}
return
}
func isTemporaryNetworkError(err error) bool {
if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
return true
}
return false
}
func containsInt(ints []int, n int) bool {
for _, i := range ints {
if i == n {
return true
}
}
return false
}
func delay(resp *http.Response, cancel <-chan struct{}) bool {
if resp == nil {
return false
}
retryAfter, _ := strconv.Atoi(resp.Header.Get("Retry-After"))
if resp.StatusCode == http.StatusTooManyRequests && retryAfter > 0 {
select {
case <-time.After(time.Duration(retryAfter) * time.Second):
return true
case <-cancel:
return false
}
}
return false
}
// SetAutoRefresh enables or disables automatic refreshing of stale tokens.
func (spt *ServicePrincipalToken) SetAutoRefresh(autoRefresh bool) {
spt.autoRefresh = autoRefresh
@@ -425,3 +779,17 @@ func (spt *ServicePrincipalToken) SetRefreshWithin(d time.Duration) {
// SetSender sets the http.Client used when obtaining the Service Principal token. An
// undecorated http.Client is used by default.
func (spt *ServicePrincipalToken) SetSender(s Sender) { spt.sender = s }
// OAuthToken implements the OAuthTokenProvider interface. It returns the current access token.
func (spt *ServicePrincipalToken) OAuthToken() string {
spt.refreshLock.RLock()
defer spt.refreshLock.RUnlock()
return spt.token.OAuthToken()
}
// Token returns a copy of the current token.
func (spt *ServicePrincipalToken) Token() Token {
spt.refreshLock.RLock()
defer spt.refreshLock.RUnlock()
return spt.token
}

View File

@@ -1,654 +0,0 @@
package adal
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"io/ioutil"
"math/big"
"net/http"
"net/url"
"os"
"reflect"
"strconv"
"strings"
"testing"
"time"
"github.com/Azure/go-autorest/autorest/date"
"github.com/Azure/go-autorest/autorest/mocks"
)
const (
defaultFormData = "client_id=id&client_secret=secret&grant_type=client_credentials&resource=resource"
defaultManualFormData = "client_id=id&grant_type=refresh_token&refresh_token=refreshtoken&resource=resource"
)
func TestTokenExpires(t *testing.T) {
tt := time.Now().Add(5 * time.Second)
tk := newTokenExpiresAt(tt)
if tk.Expires().Equal(tt) {
t.Fatalf("adal: Token#Expires miscalculated expiration time -- received %v, expected %v", tk.Expires(), tt)
}
}
func TestTokenIsExpired(t *testing.T) {
tk := newTokenExpiresAt(time.Now().Add(-5 * time.Second))
if !tk.IsExpired() {
t.Fatalf("adal: Token#IsExpired failed to mark a stale token as expired -- now %v, token expires at %v",
time.Now().UTC(), tk.Expires())
}
}
func TestTokenIsExpiredUninitialized(t *testing.T) {
tk := &Token{}
if !tk.IsExpired() {
t.Fatalf("adal: An uninitialized Token failed to mark itself as expired (expiration time %v)", tk.Expires())
}
}
func TestTokenIsNoExpired(t *testing.T) {
tk := newTokenExpiresAt(time.Now().Add(1000 * time.Second))
if tk.IsExpired() {
t.Fatalf("adal: Token marked a fresh token as expired -- now %v, token expires at %v", time.Now().UTC(), tk.Expires())
}
}
func TestTokenWillExpireIn(t *testing.T) {
d := 5 * time.Second
tk := newTokenExpiresIn(d)
if !tk.WillExpireIn(d) {
t.Fatal("adal: Token#WillExpireIn mismeasured expiration time")
}
}
func TestServicePrincipalTokenSetAutoRefresh(t *testing.T) {
spt := newServicePrincipalToken()
if !spt.autoRefresh {
t.Fatal("adal: ServicePrincipalToken did not default to automatic token refreshing")
}
spt.SetAutoRefresh(false)
if spt.autoRefresh {
t.Fatal("adal: ServicePrincipalToken#SetAutoRefresh did not disable automatic token refreshing")
}
}
func TestServicePrincipalTokenSetRefreshWithin(t *testing.T) {
spt := newServicePrincipalToken()
if spt.refreshWithin != defaultRefresh {
t.Fatal("adal: ServicePrincipalToken did not correctly set the default refresh interval")
}
spt.SetRefreshWithin(2 * defaultRefresh)
if spt.refreshWithin != 2*defaultRefresh {
t.Fatal("adal: ServicePrincipalToken#SetRefreshWithin did not set the refresh interval")
}
}
func TestServicePrincipalTokenSetSender(t *testing.T) {
spt := newServicePrincipalToken()
c := &http.Client{}
spt.SetSender(c)
if !reflect.DeepEqual(c, spt.sender) {
t.Fatal("adal: ServicePrincipalToken#SetSender did not set the sender")
}
}
func TestServicePrincipalTokenRefreshUsesPOST(t *testing.T) {
spt := newServicePrincipalToken()
body := mocks.NewBody(newTokenJSON("test", "test"))
resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
if r.Method != "POST" {
t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set HTTP method -- expected %v, received %v", "POST", r.Method)
}
return resp, nil
})
}
})())
spt.SetSender(s)
err := spt.Refresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
}
if body.IsOpen() {
t.Fatalf("the response was not closed!")
}
}
func TestServicePrincipalTokenFromMSIRefreshUsesPOST(t *testing.T) {
resource := "https://resource"
cb := func(token Token) error { return nil }
spt, err := NewServicePrincipalTokenFromMSI("http://msiendpoint/", resource, cb)
if err != nil {
t.Fatalf("Failed to get MSI SPT: %v", err)
}
body := mocks.NewBody(newTokenJSON("test", "test"))
resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
if r.Method != "POST" {
t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set HTTP method -- expected %v, received %v", "POST", r.Method)
}
if h := r.Header.Get("Metadata"); h != "true" {
t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set Metadata header for MSI")
}
return resp, nil
})
}
})())
spt.SetSender(s)
err = spt.Refresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
}
if body.IsOpen() {
t.Fatalf("the response was not closed!")
}
}
func TestServicePrincipalTokenRefreshSetsMimeType(t *testing.T) {
spt := newServicePrincipalToken()
body := mocks.NewBody(newTokenJSON("test", "test"))
resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
if r.Header.Get(http.CanonicalHeaderKey("Content-Type")) != "application/x-www-form-urlencoded" {
t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set Content-Type -- expected %v, received %v",
"application/x-form-urlencoded",
r.Header.Get(http.CanonicalHeaderKey("Content-Type")))
}
return resp, nil
})
}
})())
spt.SetSender(s)
err := spt.Refresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
}
}
func TestServicePrincipalTokenRefreshSetsURL(t *testing.T) {
spt := newServicePrincipalToken()
body := mocks.NewBody(newTokenJSON("test", "test"))
resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.String() != TestOAuthConfig.TokenEndpoint.String() {
t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set the URL -- expected %v, received %v",
TestOAuthConfig.TokenEndpoint, r.URL)
}
return resp, nil
})
}
})())
spt.SetSender(s)
err := spt.Refresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
}
}
func testServicePrincipalTokenRefreshSetsBody(t *testing.T, spt *ServicePrincipalToken, f func(*testing.T, []byte)) {
body := mocks.NewBody(newTokenJSON("test", "test"))
resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("adal: Failed to read body of Service Principal token request (%v)", err)
}
f(t, b)
return resp, nil
})
}
})())
spt.SetSender(s)
err := spt.Refresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
}
}
func TestServicePrincipalTokenManualRefreshSetsBody(t *testing.T) {
sptManual := newServicePrincipalTokenManual()
testServicePrincipalTokenRefreshSetsBody(t, sptManual, func(t *testing.T, b []byte) {
if string(b) != defaultManualFormData {
t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set the HTTP Request Body -- expected %v, received %v",
defaultManualFormData, string(b))
}
})
}
func TestServicePrincipalTokenCertficateRefreshSetsBody(t *testing.T) {
sptCert := newServicePrincipalTokenCertificate(t)
testServicePrincipalTokenRefreshSetsBody(t, sptCert, func(t *testing.T, b []byte) {
body := string(b)
values, _ := url.ParseQuery(body)
if values["client_assertion_type"][0] != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" ||
values["client_id"][0] != "id" ||
values["grant_type"][0] != "client_credentials" ||
values["resource"][0] != "resource" {
t.Fatalf("adal: ServicePrincipalTokenCertificate#Refresh did not correctly set the HTTP Request Body.")
}
})
}
func TestServicePrincipalTokenSecretRefreshSetsBody(t *testing.T) {
spt := newServicePrincipalToken()
testServicePrincipalTokenRefreshSetsBody(t, spt, func(t *testing.T, b []byte) {
if string(b) != defaultFormData {
t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set the HTTP Request Body -- expected %v, received %v",
defaultFormData, string(b))
}
})
}
func TestServicePrincipalTokenRefreshClosesRequestBody(t *testing.T) {
spt := newServicePrincipalToken()
body := mocks.NewBody(newTokenJSON("test", "test"))
resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
return resp, nil
})
}
})())
spt.SetSender(s)
err := spt.Refresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
}
if resp.Body.(*mocks.Body).IsOpen() {
t.Fatal("adal: ServicePrincipalToken#Refresh failed to close the HTTP Response Body")
}
}
func TestServicePrincipalTokenRefreshRejectsResponsesWithStatusNotOK(t *testing.T) {
spt := newServicePrincipalToken()
body := mocks.NewBody(newTokenJSON("test", "test"))
resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusUnauthorized, "Unauthorized")
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
return resp, nil
})
}
})())
spt.SetSender(s)
err := spt.Refresh()
if err == nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh should reject a response with status != %d", http.StatusOK)
}
}
func TestServicePrincipalTokenRefreshRejectsEmptyBody(t *testing.T) {
spt := newServicePrincipalToken()
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
return mocks.NewResponse(), nil
})
}
})())
spt.SetSender(s)
err := spt.Refresh()
if err == nil {
t.Fatal("adal: ServicePrincipalToken#Refresh should reject an empty token")
}
}
func TestServicePrincipalTokenRefreshPropagatesErrors(t *testing.T) {
spt := newServicePrincipalToken()
c := mocks.NewSender()
c.SetError(fmt.Errorf("Faux Error"))
spt.SetSender(c)
err := spt.Refresh()
if err == nil {
t.Fatal("adal: Failed to propagate the request error")
}
}
func TestServicePrincipalTokenRefreshReturnsErrorIfNotOk(t *testing.T) {
spt := newServicePrincipalToken()
c := mocks.NewSender()
c.AppendResponse(mocks.NewResponseWithStatus("401 NotAuthorized", http.StatusUnauthorized))
spt.SetSender(c)
err := spt.Refresh()
if err == nil {
t.Fatalf("adal: Failed to return an when receiving a status code other than HTTP %d", http.StatusOK)
}
}
func TestServicePrincipalTokenRefreshUnmarshals(t *testing.T) {
spt := newServicePrincipalToken()
expiresOn := strconv.Itoa(int(time.Now().Add(3600 * time.Second).Sub(date.UnixEpoch()).Seconds()))
j := newTokenJSON(expiresOn, "resource")
resp := mocks.NewResponseWithContent(j)
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
return resp, nil
})
}
})())
spt.SetSender(s)
err := spt.Refresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
} else if spt.AccessToken != "accessToken" ||
spt.ExpiresIn != "3600" ||
spt.ExpiresOn != expiresOn ||
spt.NotBefore != expiresOn ||
spt.Resource != "resource" ||
spt.Type != "Bearer" {
t.Fatalf("adal: ServicePrincipalToken#Refresh failed correctly unmarshal the JSON -- expected %v, received %v",
j, *spt)
}
}
func TestServicePrincipalTokenEnsureFreshRefreshes(t *testing.T) {
spt := newServicePrincipalToken()
expireToken(&spt.Token)
body := mocks.NewBody(newTokenJSON("test", "test"))
resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
f := false
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
f = true
return resp, nil
})
}
})())
spt.SetSender(s)
err := spt.EnsureFresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#EnsureFresh returned an unexpected error (%v)", err)
}
if !f {
t.Fatal("adal: ServicePrincipalToken#EnsureFresh failed to call Refresh for stale token")
}
}
func TestServicePrincipalTokenEnsureFreshSkipsIfFresh(t *testing.T) {
spt := newServicePrincipalToken()
setTokenToExpireIn(&spt.Token, 1000*time.Second)
f := false
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
f = true
return mocks.NewResponse(), nil
})
}
})())
spt.SetSender(s)
err := spt.EnsureFresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#EnsureFresh returned an unexpected error (%v)", err)
}
if f {
t.Fatal("adal: ServicePrincipalToken#EnsureFresh invoked Refresh for fresh token")
}
}
func TestRefreshCallback(t *testing.T) {
callbackTriggered := false
spt := newServicePrincipalToken(func(Token) error {
callbackTriggered = true
return nil
})
expiresOn := strconv.Itoa(int(time.Now().Add(3600 * time.Second).Sub(date.UnixEpoch()).Seconds()))
sender := mocks.NewSender()
j := newTokenJSON(expiresOn, "resource")
sender.AppendResponse(mocks.NewResponseWithContent(j))
spt.SetSender(sender)
err := spt.Refresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
}
if !callbackTriggered {
t.Fatalf("adal: RefreshCallback failed to trigger call callback")
}
}
func TestRefreshCallbackErrorPropagates(t *testing.T) {
errorText := "this is an error text"
spt := newServicePrincipalToken(func(Token) error {
return fmt.Errorf(errorText)
})
expiresOn := strconv.Itoa(int(time.Now().Add(3600 * time.Second).Sub(date.UnixEpoch()).Seconds()))
sender := mocks.NewSender()
j := newTokenJSON(expiresOn, "resource")
sender.AppendResponse(mocks.NewResponseWithContent(j))
spt.SetSender(sender)
err := spt.Refresh()
if err == nil || !strings.Contains(err.Error(), errorText) {
t.Fatalf("adal: RefreshCallback failed to propagate error")
}
}
// This demonstrates the danger of manual token without a refresh token
func TestServicePrincipalTokenManualRefreshFailsWithoutRefresh(t *testing.T) {
spt := newServicePrincipalTokenManual()
spt.RefreshToken = ""
err := spt.Refresh()
if err == nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh should have failed with a ManualTokenSecret without a refresh token")
}
}
func TestNewServicePrincipalTokenFromMSI(t *testing.T) {
resource := "https://resource"
cb := func(token Token) error { return nil }
spt, err := NewServicePrincipalTokenFromMSI("http://msiendpoint/", resource, cb)
if err != nil {
t.Fatalf("Failed to get MSI SPT: %v", err)
}
// check some of the SPT fields
if _, ok := spt.secret.(*ServicePrincipalMSISecret); !ok {
t.Fatal("SPT secret was not of MSI type")
}
if spt.resource != resource {
t.Fatal("SPT came back with incorrect resource")
}
if len(spt.refreshCallbacks) != 1 {
t.Fatal("SPT had incorrect refresh callbacks.")
}
}
func TestGetVMEndpoint(t *testing.T) {
tempSettingsFile, err := ioutil.TempFile("", "ManagedIdentity-Settings")
if err != nil {
t.Fatal("Couldn't write temp settings file")
}
defer os.Remove(tempSettingsFile.Name())
settingsContents := []byte(`{
"url": "http://msiendpoint/"
}`)
if _, err := tempSettingsFile.Write(settingsContents); err != nil {
t.Fatal("Couldn't fill temp settings file")
}
endpoint, err := getMSIVMEndpoint(tempSettingsFile.Name())
if err != nil {
t.Fatal("Coudn't get VM endpoint")
}
if endpoint != "http://msiendpoint/" {
t.Fatal("Didn't get correct endpoint")
}
}
func newToken() *Token {
return &Token{
AccessToken: "ASECRETVALUE",
Resource: "https://azure.microsoft.com/",
Type: "Bearer",
}
}
func newTokenJSON(expiresOn string, resource string) string {
return fmt.Sprintf(`{
"access_token" : "accessToken",
"expires_in" : "3600",
"expires_on" : "%s",
"not_before" : "%s",
"resource" : "%s",
"token_type" : "Bearer"
}`,
expiresOn, expiresOn, resource)
}
func newTokenExpiresIn(expireIn time.Duration) *Token {
return setTokenToExpireIn(newToken(), expireIn)
}
func newTokenExpiresAt(expireAt time.Time) *Token {
return setTokenToExpireAt(newToken(), expireAt)
}
func expireToken(t *Token) *Token {
return setTokenToExpireIn(t, 0)
}
func setTokenToExpireAt(t *Token, expireAt time.Time) *Token {
t.ExpiresIn = "3600"
t.ExpiresOn = strconv.Itoa(int(expireAt.Sub(date.UnixEpoch()).Seconds()))
t.NotBefore = t.ExpiresOn
return t
}
func setTokenToExpireIn(t *Token, expireIn time.Duration) *Token {
return setTokenToExpireAt(t, time.Now().Add(expireIn))
}
func newServicePrincipalToken(callbacks ...TokenRefreshCallback) *ServicePrincipalToken {
spt, _ := NewServicePrincipalToken(TestOAuthConfig, "id", "secret", "resource", callbacks...)
return spt
}
func newServicePrincipalTokenManual() *ServicePrincipalToken {
token := newToken()
token.RefreshToken = "refreshtoken"
spt, _ := NewServicePrincipalTokenFromManualToken(TestOAuthConfig, "id", "resource", *token)
return spt
}
func newServicePrincipalTokenCertificate(t *testing.T) *ServicePrincipalToken {
template := x509.Certificate{
SerialNumber: big.NewInt(0),
Subject: pkix.Name{CommonName: "test"},
BasicConstraintsValid: true,
}
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
certificateBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
if err != nil {
t.Fatal(err)
}
certificate, err := x509.ParseCertificate(certificateBytes)
if err != nil {
t.Fatal(err)
}
spt, _ := NewServicePrincipalTokenFromCertificate(TestOAuthConfig, "id", certificate, privateKey, "resource")
return spt
}

View File

@@ -1,181 +0,0 @@
package autorest
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"fmt"
"net/http"
"net/url"
"strings"
"github.com/Azure/go-autorest/autorest/adal"
)
const (
bearerChallengeHeader = "Www-Authenticate"
bearer = "Bearer"
tenantID = "tenantID"
)
// Authorizer is the interface that provides a PrepareDecorator used to supply request
// authorization. Most often, the Authorizer decorator runs last so it has access to the full
// state of the formed HTTP request.
type Authorizer interface {
WithAuthorization() PrepareDecorator
}
// NullAuthorizer implements a default, "do nothing" Authorizer.
type NullAuthorizer struct{}
// WithAuthorization returns a PrepareDecorator that does nothing.
func (na NullAuthorizer) WithAuthorization() PrepareDecorator {
return WithNothing()
}
// BearerAuthorizer implements the bearer authorization
type BearerAuthorizer struct {
tokenProvider adal.OAuthTokenProvider
}
// NewBearerAuthorizer crates a BearerAuthorizer using the given token provider
func NewBearerAuthorizer(tp adal.OAuthTokenProvider) *BearerAuthorizer {
return &BearerAuthorizer{tokenProvider: tp}
}
func (ba *BearerAuthorizer) withBearerAuthorization() PrepareDecorator {
return WithHeader(headerAuthorization, fmt.Sprintf("Bearer %s", ba.tokenProvider.OAuthToken()))
}
// WithAuthorization returns a PrepareDecorator that adds an HTTP Authorization header whose
// value is "Bearer " followed by the token.
//
// By default, the token will be automatically refreshed through the Refresher interface.
func (ba *BearerAuthorizer) WithAuthorization() PrepareDecorator {
return func(p Preparer) Preparer {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
refresher, ok := ba.tokenProvider.(adal.Refresher)
if ok {
err := refresher.EnsureFresh()
if err != nil {
return r, NewErrorWithError(err, "azure.BearerAuthorizer", "WithAuthorization", nil,
"Failed to refresh the Token for request to %s", r.URL)
}
}
return (ba.withBearerAuthorization()(p)).Prepare(r)
})
}
}
// BearerAuthorizerCallbackFunc is the authentication callback signature.
type BearerAuthorizerCallbackFunc func(tenantID, resource string) (*BearerAuthorizer, error)
// BearerAuthorizerCallback implements bearer authorization via a callback.
type BearerAuthorizerCallback struct {
sender Sender
callback BearerAuthorizerCallbackFunc
}
// NewBearerAuthorizerCallback creates a bearer authorization callback. The callback
// is invoked when the HTTP request is submitted.
func NewBearerAuthorizerCallback(sender Sender, callback BearerAuthorizerCallbackFunc) *BearerAuthorizerCallback {
if sender == nil {
sender = &http.Client{}
}
return &BearerAuthorizerCallback{sender: sender, callback: callback}
}
// WithAuthorization returns a PrepareDecorator that adds an HTTP Authorization header whose value
// is "Bearer " followed by the token. The BearerAuthorizer is obtained via a user-supplied callback.
//
// By default, the token will be automatically refreshed through the Refresher interface.
func (bacb *BearerAuthorizerCallback) WithAuthorization() PrepareDecorator {
return func(p Preparer) Preparer {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
// make a copy of the request and remove the body as it's not
// required and avoids us having to create a copy of it.
rCopy := *r
removeRequestBody(&rCopy)
resp, err := bacb.sender.Do(&rCopy)
if err == nil && resp.StatusCode == 401 {
defer resp.Body.Close()
if hasBearerChallenge(resp) {
bc, err := newBearerChallenge(resp)
if err != nil {
return r, err
}
if bacb.callback != nil {
ba, err := bacb.callback(bc.values[tenantID], bc.values["resource"])
if err != nil {
return r, err
}
return ba.WithAuthorization()(p).Prepare(r)
}
}
}
return r, err
})
}
}
// returns true if the HTTP response contains a bearer challenge
func hasBearerChallenge(resp *http.Response) bool {
authHeader := resp.Header.Get(bearerChallengeHeader)
if len(authHeader) == 0 || strings.Index(authHeader, bearer) < 0 {
return false
}
return true
}
type bearerChallenge struct {
values map[string]string
}
func newBearerChallenge(resp *http.Response) (bc bearerChallenge, err error) {
challenge := strings.TrimSpace(resp.Header.Get(bearerChallengeHeader))
trimmedChallenge := challenge[len(bearer)+1:]
// challenge is a set of key=value pairs that are comma delimited
pairs := strings.Split(trimmedChallenge, ",")
if len(pairs) < 1 {
err = fmt.Errorf("challenge '%s' contains no pairs", challenge)
return bc, err
}
bc.values = make(map[string]string)
for i := range pairs {
trimmedPair := strings.TrimSpace(pairs[i])
pair := strings.Split(trimmedPair, "=")
if len(pair) == 2 {
// remove the enclosing quotes
key := strings.Trim(pair[0], "\"")
value := strings.Trim(pair[1], "\"")
switch key {
case "authorization", "authorization_uri":
// strip the tenant ID from the authorization URL
asURL, err := url.Parse(value)
if err != nil {
return bc, err
}
bc.values[tenantID] = asURL.Path[1:]
default:
bc.values[key] = value
}
}
}
return bc, err
}

View File

@@ -1,188 +0,0 @@
package autorest
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"fmt"
"net/http"
"reflect"
"testing"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/mocks"
)
const (
TestTenantID = "TestTenantID"
TestActiveDirectoryEndpoint = "https://login/test.com/"
)
func TestWithAuthorizer(t *testing.T) {
r1 := mocks.NewRequest()
na := &NullAuthorizer{}
r2, err := Prepare(r1,
na.WithAuthorization())
if err != nil {
t.Fatalf("autorest: NullAuthorizer#WithAuthorization returned an unexpected error (%v)", err)
} else if !reflect.DeepEqual(r1, r2) {
t.Fatalf("autorest: NullAuthorizer#WithAuthorization modified the request -- received %v, expected %v", r2, r1)
}
}
func TestTokenWithAuthorization(t *testing.T) {
token := &adal.Token{
AccessToken: "TestToken",
Resource: "https://azure.microsoft.com/",
Type: "Bearer",
}
ba := NewBearerAuthorizer(token)
req, err := Prepare(&http.Request{}, ba.WithAuthorization())
if err != nil {
t.Fatalf("azure: BearerAuthorizer#WithAuthorization returned an error (%v)", err)
} else if req.Header.Get(http.CanonicalHeaderKey("Authorization")) != fmt.Sprintf("Bearer %s", token.AccessToken) {
t.Fatal("azure: BearerAuthorizer#WithAuthorization failed to set Authorization header")
}
}
func TestServicePrincipalTokenWithAuthorizationNoRefresh(t *testing.T) {
oauthConfig, err := adal.NewOAuthConfig(TestActiveDirectoryEndpoint, TestTenantID)
if err != nil {
t.Fatalf("azure: BearerAuthorizer#WithAuthorization returned an error (%v)", err)
}
spt, err := adal.NewServicePrincipalToken(*oauthConfig, "id", "secret", "resource", nil)
if err != nil {
t.Fatalf("azure: BearerAuthorizer#WithAuthorization returned an error (%v)", err)
}
spt.SetAutoRefresh(false)
s := mocks.NewSender()
spt.SetSender(s)
ba := NewBearerAuthorizer(spt)
req, err := Prepare(mocks.NewRequest(), ba.WithAuthorization())
if err != nil {
t.Fatalf("azure: BearerAuthorizer#WithAuthorization returned an error (%v)", err)
} else if req.Header.Get(http.CanonicalHeaderKey("Authorization")) != fmt.Sprintf("Bearer %s", spt.AccessToken) {
t.Fatal("azure: BearerAuthorizer#WithAuthorization failed to set Authorization header")
}
}
func TestServicePrincipalTokenWithAuthorizationRefresh(t *testing.T) {
oauthConfig, err := adal.NewOAuthConfig(TestActiveDirectoryEndpoint, TestTenantID)
if err != nil {
t.Fatalf("azure: BearerAuthorizer#WithAuthorization returned an error (%v)", err)
}
refreshed := false
spt, err := adal.NewServicePrincipalToken(*oauthConfig, "id", "secret", "resource", func(t adal.Token) error {
refreshed = true
return nil
})
if err != nil {
t.Fatalf("azure: BearerAuthorizer#WithAuthorization returned an error (%v)", err)
}
jwt := `{
"access_token" : "accessToken",
"expires_in" : "3600",
"expires_on" : "test",
"not_before" : "test",
"resource" : "test",
"token_type" : "Bearer"
}`
body := mocks.NewBody(jwt)
resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
return resp, nil
})
}
})())
spt.SetSender(s)
ba := NewBearerAuthorizer(spt)
req, err := Prepare(mocks.NewRequest(), ba.WithAuthorization())
if err != nil {
t.Fatalf("azure: BearerAuthorizer#WithAuthorization returned an error (%v)", err)
} else if req.Header.Get(http.CanonicalHeaderKey("Authorization")) != fmt.Sprintf("Bearer %s", spt.AccessToken) {
t.Fatal("azure: BearerAuthorizer#WithAuthorization failed to set Authorization header")
}
if !refreshed {
t.Fatal("azure: BearerAuthorizer#WithAuthorization must refresh the token")
}
}
func TestServicePrincipalTokenWithAuthorizationReturnsErrorIfConnotRefresh(t *testing.T) {
oauthConfig, err := adal.NewOAuthConfig(TestActiveDirectoryEndpoint, TestTenantID)
if err != nil {
t.Fatalf("azure: BearerAuthorizer#WithAuthorization returned an error (%v)", err)
}
spt, err := adal.NewServicePrincipalToken(*oauthConfig, "id", "secret", "resource", nil)
if err != nil {
t.Fatalf("azure: BearerAuthorizer#WithAuthorization returned an error (%v)", err)
}
s := mocks.NewSender()
s.AppendResponse(mocks.NewResponseWithStatus("400 Bad Request", http.StatusBadRequest))
spt.SetSender(s)
ba := NewBearerAuthorizer(spt)
_, err = Prepare(mocks.NewRequest(), ba.WithAuthorization())
if err == nil {
t.Fatal("azure: BearerAuthorizer#WithAuthorization failed to return an error when refresh fails")
}
}
func TestBearerAuthorizerCallback(t *testing.T) {
tenantString := "123-tenantID-456"
resourceString := "https://fake.resource.net"
s := mocks.NewSender()
resp := mocks.NewResponseWithStatus("401 Unauthorized", http.StatusUnauthorized)
mocks.SetResponseHeader(resp, bearerChallengeHeader, bearer+" \"authorization\"=\"https://fake.net/"+tenantString+"\",\"resource\"=\""+resourceString+"\"")
s.AppendResponse(resp)
auth := NewBearerAuthorizerCallback(s, func(tenantID, resource string) (*BearerAuthorizer, error) {
if tenantID != tenantString {
t.Fatal("BearerAuthorizerCallback: bad tenant ID")
}
if resource != resourceString {
t.Fatal("BearerAuthorizerCallback: bad resource")
}
oauthConfig, err := adal.NewOAuthConfig(TestActiveDirectoryEndpoint, tenantID)
if err != nil {
t.Fatalf("azure: NewOAuthConfig returned an error (%v)", err)
}
spt, err := adal.NewServicePrincipalToken(*oauthConfig, "id", "secret", resource)
if err != nil {
t.Fatalf("azure: NewServicePrincipalToken returned an error (%v)", err)
}
spt.SetSender(s)
return NewBearerAuthorizer(spt), nil
})
_, err := Prepare(mocks.NewRequest(), auth.WithAuthorization())
if err == nil {
t.Fatal("azure: BearerAuthorizerCallback#WithAuthorization failed to return an error when refresh fails")
}
}

View File

@@ -1,132 +0,0 @@
/*
Package autorest implements an HTTP request pipeline suitable for use across multiple go-routines
and provides the shared routines relied on by AutoRest (see https://github.com/Azure/autorest/)
generated Go code.
The package breaks sending and responding to HTTP requests into three phases: Preparing, Sending,
and Responding. A typical pattern is:
req, err := Prepare(&http.Request{},
token.WithAuthorization())
resp, err := Send(req,
WithLogging(logger),
DoErrorIfStatusCode(http.StatusInternalServerError),
DoCloseIfError(),
DoRetryForAttempts(5, time.Second))
err = Respond(resp,
ByDiscardingBody(),
ByClosing())
Each phase relies on decorators to modify and / or manage processing. Decorators may first modify
and then pass the data along, pass the data first and then modify the result, or wrap themselves
around passing the data (such as a logger might do). Decorators run in the order provided. For
example, the following:
req, err := Prepare(&http.Request{},
WithBaseURL("https://microsoft.com/"),
WithPath("a"),
WithPath("b"),
WithPath("c"))
will set the URL to:
https://microsoft.com/a/b/c
Preparers and Responders may be shared and re-used (assuming the underlying decorators support
sharing and re-use). Performant use is obtained by creating one or more Preparers and Responders
shared among multiple go-routines, and a single Sender shared among multiple sending go-routines,
all bound together by means of input / output channels.
Decorators hold their passed state within a closure (such as the path components in the example
above). Be careful to share Preparers and Responders only in a context where such held state
applies. For example, it may not make sense to share a Preparer that applies a query string from a
fixed set of values. Similarly, sharing a Responder that reads the response body into a passed
struct (e.g., ByUnmarshallingJson) is likely incorrect.
Lastly, the Swagger specification (https://swagger.io) that drives AutoRest
(https://github.com/Azure/autorest/) precisely defines two date forms: date and date-time. The
github.com/Azure/go-autorest/autorest/date package provides time.Time derivations to ensure
correct parsing and formatting.
Errors raised by autorest objects and methods will conform to the autorest.Error interface.
See the included examples for more detail. For details on the suggested use of this package by
generated clients, see the Client described below.
*/
package autorest
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"net/http"
"time"
)
const (
// HeaderLocation specifies the HTTP Location header.
HeaderLocation = "Location"
// HeaderRetryAfter specifies the HTTP Retry-After header.
HeaderRetryAfter = "Retry-After"
)
// ResponseHasStatusCode returns true if the status code in the HTTP Response is in the passed set
// and false otherwise.
func ResponseHasStatusCode(resp *http.Response, codes ...int) bool {
if resp == nil {
return false
}
return containsInt(codes, resp.StatusCode)
}
// GetLocation retrieves the URL from the Location header of the passed response.
func GetLocation(resp *http.Response) string {
return resp.Header.Get(HeaderLocation)
}
// GetRetryAfter extracts the retry delay from the Retry-After header of the passed response. If
// the header is absent or is malformed, it will return the supplied default delay time.Duration.
func GetRetryAfter(resp *http.Response, defaultDelay time.Duration) time.Duration {
retry := resp.Header.Get(HeaderRetryAfter)
if retry == "" {
return defaultDelay
}
d, err := time.ParseDuration(retry + "s")
if err != nil {
return defaultDelay
}
return d
}
// NewPollingRequest allocates and returns a new http.Request to poll for the passed response.
func NewPollingRequest(resp *http.Response, cancel <-chan struct{}) (*http.Request, error) {
location := GetLocation(resp)
if location == "" {
return nil, NewErrorWithResponse("autorest", "NewPollingRequest", resp, "Location header missing from response that requires polling")
}
req, err := Prepare(&http.Request{Cancel: cancel},
AsGet(),
WithBaseURL(location))
if err != nil {
return nil, NewErrorWithError(err, "autorest", "NewPollingRequest", nil, "Failure creating poll request to %s", location)
}
return req, nil
}

View File

@@ -1,140 +0,0 @@
package autorest
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"net/http"
"testing"
"github.com/Azure/go-autorest/autorest/mocks"
)
func TestResponseHasStatusCode(t *testing.T) {
codes := []int{http.StatusOK, http.StatusAccepted}
resp := &http.Response{StatusCode: http.StatusAccepted}
if !ResponseHasStatusCode(resp, codes...) {
t.Fatalf("autorest: ResponseHasStatusCode failed to find %v in %v", resp.StatusCode, codes)
}
}
func TestResponseHasStatusCodeNotPresent(t *testing.T) {
codes := []int{http.StatusOK, http.StatusAccepted}
resp := &http.Response{StatusCode: http.StatusInternalServerError}
if ResponseHasStatusCode(resp, codes...) {
t.Fatalf("autorest: ResponseHasStatusCode unexpectedly found %v in %v", resp.StatusCode, codes)
}
}
func TestNewPollingRequestDoesNotReturnARequestWhenLocationHeaderIsMissing(t *testing.T) {
resp := mocks.NewResponseWithStatus("500 InternalServerError", http.StatusInternalServerError)
req, _ := NewPollingRequest(resp, nil)
if req != nil {
t.Fatal("autorest: NewPollingRequest returned an http.Request when the Location header was missing")
}
}
func TestNewPollingRequestReturnsAnErrorWhenPrepareFails(t *testing.T) {
resp := mocks.NewResponseWithStatus("202 Accepted", http.StatusAccepted)
mocks.SetAcceptedHeaders(resp)
resp.Header.Set(http.CanonicalHeaderKey(HeaderLocation), mocks.TestBadURL)
_, err := NewPollingRequest(resp, nil)
if err == nil {
t.Fatal("autorest: NewPollingRequest failed to return an error when Prepare fails")
}
}
func TestNewPollingRequestDoesNotReturnARequestWhenPrepareFails(t *testing.T) {
resp := mocks.NewResponseWithStatus("202 Accepted", http.StatusAccepted)
mocks.SetAcceptedHeaders(resp)
resp.Header.Set(http.CanonicalHeaderKey(HeaderLocation), mocks.TestBadURL)
req, _ := NewPollingRequest(resp, nil)
if req != nil {
t.Fatal("autorest: NewPollingRequest returned an http.Request when Prepare failed")
}
}
func TestNewPollingRequestReturnsAGetRequest(t *testing.T) {
resp := mocks.NewResponseWithStatus("202 Accepted", http.StatusAccepted)
mocks.SetAcceptedHeaders(resp)
req, _ := NewPollingRequest(resp, nil)
if req.Method != "GET" {
t.Fatalf("autorest: NewPollingRequest did not create an HTTP GET request -- actual method %v", req.Method)
}
}
func TestNewPollingRequestProvidesTheURL(t *testing.T) {
resp := mocks.NewResponseWithStatus("202 Accepted", http.StatusAccepted)
mocks.SetAcceptedHeaders(resp)
req, _ := NewPollingRequest(resp, nil)
if req.URL.String() != mocks.TestURL {
t.Fatalf("autorest: NewPollingRequest did not create an HTTP with the expected URL -- received %v, expected %v", req.URL, mocks.TestURL)
}
}
func TestGetLocation(t *testing.T) {
resp := mocks.NewResponseWithStatus("202 Accepted", http.StatusAccepted)
mocks.SetAcceptedHeaders(resp)
l := GetLocation(resp)
if len(l) == 0 {
t.Fatalf("autorest: GetLocation failed to return Location header -- expected %v, received %v", mocks.TestURL, l)
}
}
func TestGetLocationReturnsEmptyStringForMissingLocation(t *testing.T) {
resp := mocks.NewResponseWithStatus("202 Accepted", http.StatusAccepted)
l := GetLocation(resp)
if len(l) != 0 {
t.Fatalf("autorest: GetLocation return a value without a Location header -- received %v", l)
}
}
func TestGetRetryAfter(t *testing.T) {
resp := mocks.NewResponseWithStatus("202 Accepted", http.StatusAccepted)
mocks.SetAcceptedHeaders(resp)
d := GetRetryAfter(resp, DefaultPollingDelay)
if d != mocks.TestDelay {
t.Fatalf("autorest: GetRetryAfter failed to returned the expected delay -- expected %v, received %v", mocks.TestDelay, d)
}
}
func TestGetRetryAfterReturnsDefaultDelayIfRetryHeaderIsMissing(t *testing.T) {
resp := mocks.NewResponseWithStatus("202 Accepted", http.StatusAccepted)
d := GetRetryAfter(resp, DefaultPollingDelay)
if d != DefaultPollingDelay {
t.Fatalf("autorest: GetRetryAfter failed to returned the default delay for a missing Retry-After header -- expected %v, received %v",
DefaultPollingDelay, d)
}
}
func TestGetRetryAfterReturnsDefaultDelayIfRetryHeaderIsMalformed(t *testing.T) {
resp := mocks.NewResponseWithStatus("202 Accepted", http.StatusAccepted)
mocks.SetAcceptedHeaders(resp)
resp.Header.Set(http.CanonicalHeaderKey(HeaderRetryAfter), "a very bad non-integer value")
d := GetRetryAfter(resp, DefaultPollingDelay)
if d != DefaultPollingDelay {
t.Fatalf("autorest: GetRetryAfter failed to returned the default delay for a malformed Retry-After header -- expected %v, received %v",
DefaultPollingDelay, d)
}
}

View File

@@ -1,460 +0,0 @@
package azure
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"strings"
"time"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/date"
)
const (
headerAsyncOperation = "Azure-AsyncOperation"
)
const (
operationInProgress string = "InProgress"
operationCanceled string = "Canceled"
operationFailed string = "Failed"
operationSucceeded string = "Succeeded"
)
var pollingCodes = [...]int{http.StatusAccepted, http.StatusCreated, http.StatusOK}
// Future provides a mechanism to access the status and results of an asynchronous request.
// Since futures are stateful they should be passed by value to avoid race conditions.
type Future struct {
req *http.Request
resp *http.Response
ps pollingState
}
// NewFuture returns a new Future object initialized with the specified request.
func NewFuture(req *http.Request) Future {
return Future{req: req}
}
// Response returns the last HTTP response or nil if there isn't one.
func (f Future) Response() *http.Response {
return f.resp
}
// Status returns the last status message of the operation.
func (f Future) Status() string {
if f.ps.State == "" {
return "Unknown"
}
return f.ps.State
}
// PollingMethod returns the method used to monitor the status of the asynchronous operation.
func (f Future) PollingMethod() PollingMethodType {
return f.ps.PollingMethod
}
// Done queries the service to see if the operation has completed.
func (f *Future) Done(sender autorest.Sender) (bool, error) {
// exit early if this future has terminated
if f.ps.hasTerminated() {
return true, f.errorInfo()
}
resp, err := sender.Do(f.req)
f.resp = resp
if err != nil || !autorest.ResponseHasStatusCode(resp, pollingCodes[:]...) {
return false, err
}
err = updatePollingState(resp, &f.ps)
if err != nil {
return false, err
}
if f.ps.hasTerminated() {
return true, f.errorInfo()
}
f.req, err = newPollingRequest(f.ps)
return false, err
}
// GetPollingDelay returns a duration the application should wait before checking
// the status of the asynchronous request and true; this value is returned from
// the service via the Retry-After response header. If the header wasn't returned
// then the function returns the zero-value time.Duration and false.
func (f Future) GetPollingDelay() (time.Duration, bool) {
if f.resp == nil {
return 0, false
}
retry := f.resp.Header.Get(autorest.HeaderRetryAfter)
if retry == "" {
return 0, false
}
d, err := time.ParseDuration(retry + "s")
if err != nil {
panic(err)
}
return d, true
}
// WaitForCompletion will return when one of the following conditions is met: the long
// running operation has completed, the provided context is cancelled, or the client's
// polling duration has been exceeded. It will retry failed polling attempts based on
// the retry value defined in the client up to the maximum retry attempts.
func (f Future) WaitForCompletion(ctx context.Context, client autorest.Client) error {
ctx, cancel := context.WithTimeout(ctx, client.PollingDuration)
defer cancel()
done, err := f.Done(client)
for attempts := 0; !done; done, err = f.Done(client) {
if attempts >= client.RetryAttempts {
return autorest.NewErrorWithError(err, "azure", "WaitForCompletion", f.resp, "the number of retries has been exceeded")
}
// we want delayAttempt to be zero in the non-error case so
// that DelayForBackoff doesn't perform exponential back-off
var delayAttempt int
var delay time.Duration
if err == nil {
// check for Retry-After delay, if not present use the client's polling delay
var ok bool
delay, ok = f.GetPollingDelay()
if !ok {
delay = client.PollingDelay
}
} else {
// there was an error polling for status so perform exponential
// back-off based on the number of attempts using the client's retry
// duration. update attempts after delayAttempt to avoid off-by-one.
delayAttempt = attempts
delay = client.RetryDuration
attempts++
}
// wait until the delay elapses or the context is cancelled
delayElapsed := autorest.DelayForBackoff(delay, delayAttempt, ctx.Done())
if !delayElapsed {
return autorest.NewErrorWithError(ctx.Err(), "azure", "WaitForCompletion", f.resp, "context has been cancelled")
}
}
return err
}
// if the operation failed the polling state will contain
// error information and implements the error interface
func (f *Future) errorInfo() error {
if !f.ps.hasSucceeded() {
return f.ps
}
return nil
}
// MarshalJSON implements the json.Marshaler interface.
func (f Future) MarshalJSON() ([]byte, error) {
return json.Marshal(&f.ps)
}
// UnmarshalJSON implements the json.Unmarshaler interface.
func (f *Future) UnmarshalJSON(data []byte) error {
err := json.Unmarshal(data, &f.ps)
if err != nil {
return err
}
f.req, err = newPollingRequest(f.ps)
return err
}
// DoPollForAsynchronous returns a SendDecorator that polls if the http.Response is for an Azure
// long-running operation. It will delay between requests for the duration specified in the
// RetryAfter header or, if the header is absent, the passed delay. Polling may be canceled by
// closing the optional channel on the http.Request.
func DoPollForAsynchronous(delay time.Duration) autorest.SendDecorator {
return func(s autorest.Sender) autorest.Sender {
return autorest.SenderFunc(func(r *http.Request) (resp *http.Response, err error) {
resp, err = s.Do(r)
if err != nil {
return resp, err
}
if !autorest.ResponseHasStatusCode(resp, pollingCodes[:]...) {
return resp, nil
}
ps := pollingState{}
for err == nil {
err = updatePollingState(resp, &ps)
if err != nil {
break
}
if ps.hasTerminated() {
if !ps.hasSucceeded() {
err = ps
}
break
}
r, err = newPollingRequest(ps)
if err != nil {
return resp, err
}
r.Cancel = resp.Request.Cancel
delay = autorest.GetRetryAfter(resp, delay)
resp, err = autorest.SendWithSender(s, r,
autorest.AfterDelay(delay))
}
return resp, err
})
}
}
func getAsyncOperation(resp *http.Response) string {
return resp.Header.Get(http.CanonicalHeaderKey(headerAsyncOperation))
}
func hasSucceeded(state string) bool {
return strings.EqualFold(state, operationSucceeded)
}
func hasTerminated(state string) bool {
return strings.EqualFold(state, operationCanceled) || strings.EqualFold(state, operationFailed) || strings.EqualFold(state, operationSucceeded)
}
func hasFailed(state string) bool {
return strings.EqualFold(state, operationFailed)
}
type provisioningTracker interface {
state() string
hasSucceeded() bool
hasTerminated() bool
}
type operationResource struct {
// Note:
// The specification states services should return the "id" field. However some return it as
// "operationId".
ID string `json:"id"`
OperationID string `json:"operationId"`
Name string `json:"name"`
Status string `json:"status"`
Properties map[string]interface{} `json:"properties"`
OperationError ServiceError `json:"error"`
StartTime date.Time `json:"startTime"`
EndTime date.Time `json:"endTime"`
PercentComplete float64 `json:"percentComplete"`
}
func (or operationResource) state() string {
return or.Status
}
func (or operationResource) hasSucceeded() bool {
return hasSucceeded(or.state())
}
func (or operationResource) hasTerminated() bool {
return hasTerminated(or.state())
}
type provisioningProperties struct {
ProvisioningState string `json:"provisioningState"`
}
type provisioningStatus struct {
Properties provisioningProperties `json:"properties,omitempty"`
ProvisioningError ServiceError `json:"error,omitempty"`
}
func (ps provisioningStatus) state() string {
return ps.Properties.ProvisioningState
}
func (ps provisioningStatus) hasSucceeded() bool {
return hasSucceeded(ps.state())
}
func (ps provisioningStatus) hasTerminated() bool {
return hasTerminated(ps.state())
}
func (ps provisioningStatus) hasProvisioningError() bool {
return ps.ProvisioningError != ServiceError{}
}
// PollingMethodType defines a type used for enumerating polling mechanisms.
type PollingMethodType string
const (
// PollingAsyncOperation indicates the polling method uses the Azure-AsyncOperation header.
PollingAsyncOperation PollingMethodType = "AsyncOperation"
// PollingLocation indicates the polling method uses the Location header.
PollingLocation PollingMethodType = "Location"
// PollingUnknown indicates an unknown polling method and is the default value.
PollingUnknown PollingMethodType = ""
)
type pollingState struct {
PollingMethod PollingMethodType `json:"pollingMethod"`
URI string `json:"uri"`
State string `json:"state"`
Code string `json:"code"`
Message string `json:"message"`
}
func (ps pollingState) hasSucceeded() bool {
return hasSucceeded(ps.State)
}
func (ps pollingState) hasTerminated() bool {
return hasTerminated(ps.State)
}
func (ps pollingState) hasFailed() bool {
return hasFailed(ps.State)
}
func (ps pollingState) Error() string {
return fmt.Sprintf("Long running operation terminated with status '%s': Code=%q Message=%q", ps.State, ps.Code, ps.Message)
}
// updatePollingState maps the operation status -- retrieved from either a provisioningState
// field, the status field of an OperationResource, or inferred from the HTTP status code --
// into a well-known states. Since the process begins from the initial request, the state
// always comes from either a the provisioningState returned or is inferred from the HTTP
// status code. Subsequent requests will read an Azure OperationResource object if the
// service initially returned the Azure-AsyncOperation header. The responseFormat field notes
// the expected response format.
func updatePollingState(resp *http.Response, ps *pollingState) error {
// Determine the response shape
// -- The first response will always be a provisioningStatus response; only the polling requests,
// depending on the header returned, may be something otherwise.
var pt provisioningTracker
if ps.PollingMethod == PollingAsyncOperation {
pt = &operationResource{}
} else {
pt = &provisioningStatus{}
}
// If this is the first request (that is, the polling response shape is unknown), determine how
// to poll and what to expect
if ps.PollingMethod == PollingUnknown {
req := resp.Request
if req == nil {
return autorest.NewError("azure", "updatePollingState", "Azure Polling Error - Original HTTP request is missing")
}
// Prefer the Azure-AsyncOperation header
ps.URI = getAsyncOperation(resp)
if ps.URI != "" {
ps.PollingMethod = PollingAsyncOperation
} else {
ps.PollingMethod = PollingLocation
}
// Else, use the Location header
if ps.URI == "" {
ps.URI = autorest.GetLocation(resp)
}
// Lastly, requests against an existing resource, use the last request URI
if ps.URI == "" {
m := strings.ToUpper(req.Method)
if m == http.MethodPatch || m == http.MethodPut || m == http.MethodGet {
ps.URI = req.URL.String()
}
}
}
// Read and interpret the response (saving the Body in case no polling is necessary)
b := &bytes.Buffer{}
err := autorest.Respond(resp,
autorest.ByCopying(b),
autorest.ByUnmarshallingJSON(pt),
autorest.ByClosing())
resp.Body = ioutil.NopCloser(b)
if err != nil {
return err
}
// Interpret the results
// -- Terminal states apply regardless
// -- Unknown states are per-service inprogress states
// -- Otherwise, infer state from HTTP status code
if pt.hasTerminated() {
ps.State = pt.state()
} else if pt.state() != "" {
ps.State = operationInProgress
} else {
switch resp.StatusCode {
case http.StatusAccepted:
ps.State = operationInProgress
case http.StatusNoContent, http.StatusCreated, http.StatusOK:
ps.State = operationSucceeded
default:
ps.State = operationFailed
}
}
if strings.EqualFold(ps.State, operationInProgress) && ps.URI == "" {
return autorest.NewError("azure", "updatePollingState", "Azure Polling Error - Unable to obtain polling URI for %s %s", resp.Request.Method, resp.Request.URL)
}
// For failed operation, check for error code and message in
// -- Operation resource
// -- Response
// -- Otherwise, Unknown
if ps.hasFailed() {
if ps.PollingMethod == PollingAsyncOperation {
or := pt.(*operationResource)
ps.Code = or.OperationError.Code
ps.Message = or.OperationError.Message
} else {
p := pt.(*provisioningStatus)
if p.hasProvisioningError() {
ps.Code = p.ProvisioningError.Code
ps.Message = p.ProvisioningError.Message
} else {
ps.Code = "Unknown"
ps.Message = "None"
}
}
}
return nil
}
func newPollingRequest(ps pollingState) (*http.Request, error) {
reqPoll, err := autorest.Prepare(&http.Request{},
autorest.AsGet(),
autorest.WithBaseURL(ps.URI))
if err != nil {
return nil, autorest.NewErrorWithError(err, "azure", "newPollingRequest", nil, "Failure creating poll request to %s", ps.URI)
}
return reqPoll, nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,143 +0,0 @@
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package auth
import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"os"
"strings"
"unicode/utf16"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/dimchansky/utfbom"
)
// ClientSetup includes authentication details and cloud specific
// parameters for ARM clients
type ClientSetup struct {
*autorest.BearerAuthorizer
File
BaseURI string
}
// File represents the authentication file
type File struct {
ClientID string `json:"clientId,omitempty"`
ClientSecret string `json:"clientSecret,omitempty"`
SubscriptionID string `json:"subscriptionId,omitempty"`
TenantID string `json:"tenantId,omitempty"`
ActiveDirectoryEndpoint string `json:"activeDirectoryEndpointUrl,omitempty"`
ResourceManagerEndpoint string `json:"resourceManagerEndpointUrl,omitempty"`
GraphResourceID string `json:"activeDirectoryGraphResourceId,omitempty"`
SQLManagementEndpoint string `json:"sqlManagementEndpointUrl,omitempty"`
GalleryEndpoint string `json:"galleryEndpointUrl,omitempty"`
ManagementEndpoint string `json:"managementEndpointUrl,omitempty"`
}
// GetClientSetup provides an authorizer, base URI, subscriptionID and
// tenantID parameters from an Azure CLI auth file
func GetClientSetup(baseURI string) (auth ClientSetup, err error) {
fileLocation := os.Getenv("AZURE_AUTH_LOCATION")
if fileLocation == "" {
return auth, errors.New("auth file not found. Environment variable AZURE_AUTH_LOCATION is not set")
}
contents, err := ioutil.ReadFile(fileLocation)
if err != nil {
return
}
// Auth file might be encoded
decoded, err := decode(contents)
if err != nil {
return
}
err = json.Unmarshal(decoded, &auth.File)
if err != nil {
return
}
resource, err := getResourceForToken(auth.File, baseURI)
if err != nil {
return
}
auth.BaseURI = resource
config, err := adal.NewOAuthConfig(auth.ActiveDirectoryEndpoint, auth.TenantID)
if err != nil {
return
}
spToken, err := adal.NewServicePrincipalToken(*config, auth.ClientID, auth.ClientSecret, resource)
if err != nil {
return
}
auth.BearerAuthorizer = autorest.NewBearerAuthorizer(spToken)
return
}
func decode(b []byte) ([]byte, error) {
reader, enc := utfbom.Skip(bytes.NewReader(b))
switch enc {
case utfbom.UTF16LittleEndian:
u16 := make([]uint16, (len(b)/2)-1)
err := binary.Read(reader, binary.LittleEndian, &u16)
if err != nil {
return nil, err
}
return []byte(string(utf16.Decode(u16))), nil
case utfbom.UTF16BigEndian:
u16 := make([]uint16, (len(b)/2)-1)
err := binary.Read(reader, binary.BigEndian, &u16)
if err != nil {
return nil, err
}
return []byte(string(utf16.Decode(u16))), nil
}
return ioutil.ReadAll(reader)
}
func getResourceForToken(f File, baseURI string) (string, error) {
// Compare dafault base URI from the SDK to the endpoints from the public cloud
// Base URI and token resource are the same string. This func finds the authentication
// file field that matches the SDK base URI. The SDK defines the public cloud
// endpoint as its default base URI
if !strings.HasSuffix(baseURI, "/") {
baseURI += "/"
}
switch baseURI {
case azure.PublicCloud.ServiceManagementEndpoint:
return f.ManagementEndpoint, nil
case azure.PublicCloud.ResourceManagerEndpoint:
return f.ResourceManagerEndpoint, nil
case azure.PublicCloud.ActiveDirectoryEndpoint:
return f.ActiveDirectoryEndpoint, nil
case azure.PublicCloud.GalleryEndpoint:
return f.GalleryEndpoint, nil
case azure.PublicCloud.GraphEndpoint:
return f.GraphResourceID, nil
}
return "", fmt.Errorf("auth: base URI not found in endpoints")
}

View File

@@ -1,111 +0,0 @@
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package auth
import (
"encoding/json"
"io/ioutil"
"os"
"path/filepath"
"reflect"
"testing"
)
var (
expectedFile = File{
ClientID: "client-id-123",
ClientSecret: "client-secret-456",
SubscriptionID: "sub-id-789",
TenantID: "tenant-id-123",
ActiveDirectoryEndpoint: "https://login.microsoftonline.com",
ResourceManagerEndpoint: "https://management.azure.com/",
GraphResourceID: "https://graph.windows.net/",
SQLManagementEndpoint: "https://management.core.windows.net:8443/",
GalleryEndpoint: "https://gallery.azure.com/",
ManagementEndpoint: "https://management.core.windows.net/",
}
)
func TestGetClientSetup(t *testing.T) {
os.Setenv("AZURE_AUTH_LOCATION", filepath.Join(getCredsPath(), "credsutf16le.json"))
setup, err := GetClientSetup("https://management.azure.com")
if err != nil {
t.Logf("GetClientSetup failed, got error %v", err)
t.Fail()
}
if setup.BaseURI != "https://management.azure.com/" {
t.Logf("auth.BaseURI not set correctly, expected 'https://management.azure.com/', got '%s'", setup.BaseURI)
t.Fail()
}
if !reflect.DeepEqual(expectedFile, setup.File) {
t.Logf("auth.File not set correctly, expected %v, got %v", expectedFile, setup.File)
t.Fail()
}
if setup.BearerAuthorizer == nil {
t.Log("auth.Authorizer not set correctly, got nil")
t.Fail()
}
}
func TestDecodeAndUnmarshal(t *testing.T) {
tests := []string{
"credsutf8.json",
"credsutf16le.json",
"credsutf16be.json",
}
creds := getCredsPath()
for _, test := range tests {
b, err := ioutil.ReadFile(filepath.Join(creds, test))
if err != nil {
t.Logf("error reading file '%s': %s", test, err)
t.Fail()
}
decoded, err := decode(b)
if err != nil {
t.Logf("error decoding file '%s': %s", test, err)
t.Fail()
}
var got File
err = json.Unmarshal(decoded, &got)
if err != nil {
t.Logf("error unmarshaling file '%s': %s", test, err)
t.Fail()
}
if !reflect.DeepEqual(expectedFile, got) {
t.Logf("unmarshaled map expected %v, got %v", expectedFile, got)
t.Fail()
}
}
}
func getCredsPath() string {
gopath := os.Getenv("GOPATH")
return filepath.Join(gopath, "src", "github.com", "Azure", "go-autorest", "testdata")
}
func areMapsEqual(a, b map[string]string) bool {
if len(a) != len(b) {
return false
}
for k := range a {
if a[k] != b[k] {
return false
}
}
return true
}

View File

@@ -1,200 +0,0 @@
/*
Package azure provides Azure-specific implementations used with AutoRest.
See the included examples for more detail.
*/
package azure
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"strconv"
"github.com/Azure/go-autorest/autorest"
)
const (
// HeaderClientID is the Azure extension header to set a user-specified request ID.
HeaderClientID = "x-ms-client-request-id"
// HeaderReturnClientID is the Azure extension header to set if the user-specified request ID
// should be included in the response.
HeaderReturnClientID = "x-ms-return-client-request-id"
// HeaderRequestID is the Azure extension header of the service generated request ID returned
// in the response.
HeaderRequestID = "x-ms-request-id"
)
// ServiceError encapsulates the error response from an Azure service.
type ServiceError struct {
Code string `json:"code"`
Message string `json:"message"`
Details *[]interface{} `json:"details"`
}
func (se ServiceError) Error() string {
if se.Details != nil {
d, err := json.Marshal(*(se.Details))
if err != nil {
return fmt.Sprintf("Code=%q Message=%q Details=%v", se.Code, se.Message, *se.Details)
}
return fmt.Sprintf("Code=%q Message=%q Details=%v", se.Code, se.Message, string(d))
}
return fmt.Sprintf("Code=%q Message=%q", se.Code, se.Message)
}
// RequestError describes an error response returned by Azure service.
type RequestError struct {
autorest.DetailedError
// The error returned by the Azure service.
ServiceError *ServiceError `json:"error"`
// The request id (from the x-ms-request-id-header) of the request.
RequestID string
}
// Error returns a human-friendly error message from service error.
func (e RequestError) Error() string {
return fmt.Sprintf("autorest/azure: Service returned an error. Status=%v %v",
e.StatusCode, e.ServiceError)
}
// IsAzureError returns true if the passed error is an Azure Service error; false otherwise.
func IsAzureError(e error) bool {
_, ok := e.(*RequestError)
return ok
}
// NewErrorWithError creates a new Error conforming object from the
// passed packageType, method, statusCode of the given resp (UndefinedStatusCode
// if resp is nil), message, and original error. message is treated as a format
// string to which the optional args apply.
func NewErrorWithError(original error, packageType string, method string, resp *http.Response, message string, args ...interface{}) RequestError {
if v, ok := original.(*RequestError); ok {
return *v
}
statusCode := autorest.UndefinedStatusCode
if resp != nil {
statusCode = resp.StatusCode
}
return RequestError{
DetailedError: autorest.DetailedError{
Original: original,
PackageType: packageType,
Method: method,
StatusCode: statusCode,
Message: fmt.Sprintf(message, args...),
},
}
}
// WithReturningClientID returns a PrepareDecorator that adds an HTTP extension header of
// x-ms-client-request-id whose value is the passed, undecorated UUID (e.g.,
// "0F39878C-5F76-4DB8-A25D-61D2C193C3CA"). It also sets the x-ms-return-client-request-id
// header to true such that UUID accompanies the http.Response.
func WithReturningClientID(uuid string) autorest.PrepareDecorator {
preparer := autorest.CreatePreparer(
WithClientID(uuid),
WithReturnClientID(true))
return func(p autorest.Preparer) autorest.Preparer {
return autorest.PreparerFunc(func(r *http.Request) (*http.Request, error) {
r, err := p.Prepare(r)
if err != nil {
return r, err
}
return preparer.Prepare(r)
})
}
}
// WithClientID returns a PrepareDecorator that adds an HTTP extension header of
// x-ms-client-request-id whose value is passed, undecorated UUID (e.g.,
// "0F39878C-5F76-4DB8-A25D-61D2C193C3CA").
func WithClientID(uuid string) autorest.PrepareDecorator {
return autorest.WithHeader(HeaderClientID, uuid)
}
// WithReturnClientID returns a PrepareDecorator that adds an HTTP extension header of
// x-ms-return-client-request-id whose boolean value indicates if the value of the
// x-ms-client-request-id header should be included in the http.Response.
func WithReturnClientID(b bool) autorest.PrepareDecorator {
return autorest.WithHeader(HeaderReturnClientID, strconv.FormatBool(b))
}
// ExtractClientID extracts the client identifier from the x-ms-client-request-id header set on the
// http.Request sent to the service (and returned in the http.Response)
func ExtractClientID(resp *http.Response) string {
return autorest.ExtractHeaderValue(HeaderClientID, resp)
}
// ExtractRequestID extracts the Azure server generated request identifier from the
// x-ms-request-id header.
func ExtractRequestID(resp *http.Response) string {
return autorest.ExtractHeaderValue(HeaderRequestID, resp)
}
// WithErrorUnlessStatusCode returns a RespondDecorator that emits an
// azure.RequestError by reading the response body unless the response HTTP status code
// is among the set passed.
//
// If there is a chance service may return responses other than the Azure error
// format and the response cannot be parsed into an error, a decoding error will
// be returned containing the response body. In any case, the Responder will
// return an error if the status code is not satisfied.
//
// If this Responder returns an error, the response body will be replaced with
// an in-memory reader, which needs no further closing.
func WithErrorUnlessStatusCode(codes ...int) autorest.RespondDecorator {
return func(r autorest.Responder) autorest.Responder {
return autorest.ResponderFunc(func(resp *http.Response) error {
err := r.Respond(resp)
if err == nil && !autorest.ResponseHasStatusCode(resp, codes...) {
var e RequestError
defer resp.Body.Close()
// Copy and replace the Body in case it does not contain an error object.
// This will leave the Body available to the caller.
b, decodeErr := autorest.CopyAndDecode(autorest.EncodedAsJSON, resp.Body, &e)
resp.Body = ioutil.NopCloser(&b)
if decodeErr != nil {
return fmt.Errorf("autorest/azure: error response cannot be parsed: %q error: %v", b.String(), decodeErr)
} else if e.ServiceError == nil {
// Check if error is unwrapped ServiceError
if err := json.Unmarshal(b.Bytes(), &e.ServiceError); err != nil || e.ServiceError.Message == "" {
e.ServiceError = &ServiceError{
Code: "Unknown",
Message: "Unknown service error",
}
}
}
e.RequestID = ExtractRequestID(resp)
if e.StatusCode == nil {
e.StatusCode = resp.StatusCode
}
err = &e
}
return err
})
}
}

View File

@@ -1,513 +0,0 @@
package azure
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"reflect"
"strconv"
"testing"
"time"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/mocks"
)
const (
headerAuthorization = "Authorization"
longDelay = 5 * time.Second
retryDelay = 10 * time.Millisecond
testLogPrefix = "azure:"
)
// Use a Client Inspector to set the request identifier.
func ExampleWithClientID() {
uuid := "71FDB9F4-5E49-4C12-B266-DE7B4FD999A6"
req, _ := autorest.Prepare(&http.Request{},
autorest.AsGet(),
autorest.WithBaseURL("https://microsoft.com/a/b/c/"))
c := autorest.Client{Sender: mocks.NewSender()}
c.RequestInspector = WithReturningClientID(uuid)
autorest.SendWithSender(c, req)
fmt.Printf("Inspector added the %s header with the value %s\n",
HeaderClientID, req.Header.Get(HeaderClientID))
fmt.Printf("Inspector added the %s header with the value %s\n",
HeaderReturnClientID, req.Header.Get(HeaderReturnClientID))
// Output:
// Inspector added the x-ms-client-request-id header with the value 71FDB9F4-5E49-4C12-B266-DE7B4FD999A6
// Inspector added the x-ms-return-client-request-id header with the value true
}
func TestWithReturningClientIDReturnsError(t *testing.T) {
var errIn error
uuid := "71FDB9F4-5E49-4C12-B266-DE7B4FD999A6"
_, errOut := autorest.Prepare(&http.Request{},
withErrorPrepareDecorator(&errIn),
WithReturningClientID(uuid))
if errOut == nil || errIn != errOut {
t.Fatalf("azure: WithReturningClientID failed to exit early when receiving an error -- expected (%v), received (%v)",
errIn, errOut)
}
}
func TestWithClientID(t *testing.T) {
uuid := "71FDB9F4-5E49-4C12-B266-DE7B4FD999A6"
req, _ := autorest.Prepare(&http.Request{},
WithClientID(uuid))
if req.Header.Get(HeaderClientID) != uuid {
t.Fatalf("azure: WithClientID failed to set %s -- expected %s, received %s",
HeaderClientID, uuid, req.Header.Get(HeaderClientID))
}
}
func TestWithReturnClientID(t *testing.T) {
b := false
req, _ := autorest.Prepare(&http.Request{},
WithReturnClientID(b))
if req.Header.Get(HeaderReturnClientID) != strconv.FormatBool(b) {
t.Fatalf("azure: WithReturnClientID failed to set %s -- expected %s, received %s",
HeaderClientID, strconv.FormatBool(b), req.Header.Get(HeaderClientID))
}
}
func TestExtractClientID(t *testing.T) {
uuid := "71FDB9F4-5E49-4C12-B266-DE7B4FD999A6"
resp := mocks.NewResponse()
mocks.SetResponseHeader(resp, HeaderClientID, uuid)
if ExtractClientID(resp) != uuid {
t.Fatalf("azure: ExtractClientID failed to extract the %s -- expected %s, received %s",
HeaderClientID, uuid, ExtractClientID(resp))
}
}
func TestExtractRequestID(t *testing.T) {
uuid := "71FDB9F4-5E49-4C12-B266-DE7B4FD999A6"
resp := mocks.NewResponse()
mocks.SetResponseHeader(resp, HeaderRequestID, uuid)
if ExtractRequestID(resp) != uuid {
t.Fatalf("azure: ExtractRequestID failed to extract the %s -- expected %s, received %s",
HeaderRequestID, uuid, ExtractRequestID(resp))
}
}
func TestIsAzureError_ReturnsTrueForAzureError(t *testing.T) {
if !IsAzureError(&RequestError{}) {
t.Fatalf("azure: IsAzureError failed to return true for an Azure Service error")
}
}
func TestIsAzureError_ReturnsFalseForNonAzureError(t *testing.T) {
if IsAzureError(fmt.Errorf("An Error")) {
t.Fatalf("azure: IsAzureError return true for an non-Azure Service error")
}
}
func TestNewErrorWithError_UsesReponseStatusCode(t *testing.T) {
e := NewErrorWithError(fmt.Errorf("Error"), "packageType", "method", mocks.NewResponseWithStatus("Forbidden", http.StatusForbidden), "message")
if e.StatusCode != http.StatusForbidden {
t.Fatalf("azure: NewErrorWithError failed to use the Status Code of the passed Response -- expected %v, received %v", http.StatusForbidden, e.StatusCode)
}
}
func TestNewErrorWithError_ReturnsUnwrappedError(t *testing.T) {
e1 := RequestError{}
e1.ServiceError = &ServiceError{Code: "42", Message: "A Message"}
e1.StatusCode = 200
e1.RequestID = "A RequestID"
e2 := NewErrorWithError(&e1, "packageType", "method", nil, "message")
if !reflect.DeepEqual(e1, e2) {
t.Fatalf("azure: NewErrorWithError wrapped an RequestError -- expected %T, received %T", e1, e2)
}
}
func TestNewErrorWithError_WrapsAnError(t *testing.T) {
e1 := fmt.Errorf("Inner Error")
var e2 interface{} = NewErrorWithError(e1, "packageType", "method", nil, "message")
if _, ok := e2.(RequestError); !ok {
t.Fatalf("azure: NewErrorWithError failed to wrap a standard error -- received %T", e2)
}
}
func TestWithErrorUnlessStatusCode_NotAnAzureError(t *testing.T) {
body := `<html>
<head>
<title>IIS Error page</title>
</head>
<body>Some non-JSON error page</body>
</html>`
r := mocks.NewResponseWithContent(body)
r.Request = mocks.NewRequest()
r.StatusCode = http.StatusBadRequest
r.Status = http.StatusText(r.StatusCode)
err := autorest.Respond(r,
WithErrorUnlessStatusCode(http.StatusOK),
autorest.ByClosing())
ok, _ := err.(*RequestError)
if ok != nil {
t.Fatalf("azure: azure.RequestError returned from malformed response: %v", err)
}
// the error body should still be there
defer r.Body.Close()
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatal(err)
}
if string(b) != body {
t.Fatalf("response body is wrong. got=%q exptected=%q", string(b), body)
}
}
func TestWithErrorUnlessStatusCode_FoundAzureErrorWithoutDetails(t *testing.T) {
j := `{
"error": {
"code": "InternalError",
"message": "Azure is having trouble right now."
}
}`
uuid := "71FDB9F4-5E49-4C12-B266-DE7B4FD999A6"
r := mocks.NewResponseWithContent(j)
mocks.SetResponseHeader(r, HeaderRequestID, uuid)
r.Request = mocks.NewRequest()
r.StatusCode = http.StatusInternalServerError
r.Status = http.StatusText(r.StatusCode)
err := autorest.Respond(r,
WithErrorUnlessStatusCode(http.StatusOK),
autorest.ByClosing())
if err == nil {
t.Fatalf("azure: returned nil error for proper error response")
}
azErr, ok := err.(*RequestError)
if !ok {
t.Fatalf("azure: returned error is not azure.RequestError: %T", err)
}
expected := "autorest/azure: Service returned an error. Status=500 Code=\"InternalError\" Message=\"Azure is having trouble right now.\""
if !reflect.DeepEqual(expected, azErr.Error()) {
t.Fatalf("azure: service error is not unmarshaled properly.\nexpected=%v\ngot=%v", expected, azErr.Error())
}
if expected := http.StatusInternalServerError; azErr.StatusCode != expected {
t.Fatalf("azure: got wrong StatusCode=%d Expected=%d", azErr.StatusCode, expected)
}
if expected := uuid; azErr.RequestID != expected {
t.Fatalf("azure: wrong request ID in error. expected=%q; got=%q", expected, azErr.RequestID)
}
_ = azErr.Error()
// the error body should still be there
defer r.Body.Close()
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatal(err)
}
if string(b) != j {
t.Fatalf("response body is wrong. got=%q expected=%q", string(b), j)
}
}
func TestWithErrorUnlessStatusCode_FoundAzureErrorWithDetails(t *testing.T) {
j := `{
"error": {
"code": "InternalError",
"message": "Azure is having trouble right now.",
"details": [{"code": "conflict1", "message":"error message1"},
{"code": "conflict2", "message":"error message2"}]
}
}`
uuid := "71FDB9F4-5E49-4C12-B266-DE7B4FD999A6"
r := mocks.NewResponseWithContent(j)
mocks.SetResponseHeader(r, HeaderRequestID, uuid)
r.Request = mocks.NewRequest()
r.StatusCode = http.StatusInternalServerError
r.Status = http.StatusText(r.StatusCode)
err := autorest.Respond(r,
WithErrorUnlessStatusCode(http.StatusOK),
autorest.ByClosing())
if err == nil {
t.Fatalf("azure: returned nil error for proper error response")
}
azErr, ok := err.(*RequestError)
if !ok {
t.Fatalf("azure: returned error is not azure.RequestError: %T", err)
}
if expected := "InternalError"; azErr.ServiceError.Code != expected {
t.Fatalf("azure: wrong error code. expected=%q; got=%q", expected, azErr.ServiceError.Code)
}
if azErr.ServiceError.Message == "" {
t.Fatalf("azure: error message is not unmarshaled properly")
}
b, _ := json.Marshal(*azErr.ServiceError.Details)
if string(b) != `[{"code":"conflict1","message":"error message1"},{"code":"conflict2","message":"error message2"}]` {
t.Fatalf("azure: error details is not unmarshaled properly")
}
if expected := http.StatusInternalServerError; azErr.StatusCode != expected {
t.Fatalf("azure: got wrong StatusCode=%v Expected=%d", azErr.StatusCode, expected)
}
if expected := uuid; azErr.RequestID != expected {
t.Fatalf("azure: wrong request ID in error. expected=%q; got=%q", expected, azErr.RequestID)
}
_ = azErr.Error()
// the error body should still be there
defer r.Body.Close()
b, err = ioutil.ReadAll(r.Body)
if err != nil {
t.Fatal(err)
}
if string(b) != j {
t.Fatalf("response body is wrong. got=%q expected=%q", string(b), j)
}
}
func TestWithErrorUnlessStatusCode_NoAzureError(t *testing.T) {
j := `{
"Status":"NotFound"
}`
uuid := "71FDB9F4-5E49-4C12-B266-DE7B4FD999A6"
r := mocks.NewResponseWithContent(j)
mocks.SetResponseHeader(r, HeaderRequestID, uuid)
r.Request = mocks.NewRequest()
r.StatusCode = http.StatusInternalServerError
r.Status = http.StatusText(r.StatusCode)
err := autorest.Respond(r,
WithErrorUnlessStatusCode(http.StatusOK),
autorest.ByClosing())
if err == nil {
t.Fatalf("azure: returned nil error for proper error response")
}
azErr, ok := err.(*RequestError)
if !ok {
t.Fatalf("azure: returned error is not azure.RequestError: %T", err)
}
expected := &ServiceError{
Code: "Unknown",
Message: "Unknown service error",
}
if !reflect.DeepEqual(expected, azErr.ServiceError) {
t.Fatalf("azure: service error is not unmarshaled properly. expected=%q\ngot=%q", expected, azErr.ServiceError)
}
if expected := http.StatusInternalServerError; azErr.StatusCode != expected {
t.Fatalf("azure: got wrong StatusCode=%v Expected=%d", azErr.StatusCode, expected)
}
if expected := uuid; azErr.RequestID != expected {
t.Fatalf("azure: wrong request ID in error. expected=%q; got=%q", expected, azErr.RequestID)
}
_ = azErr.Error()
// the error body should still be there
defer r.Body.Close()
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatal(err)
}
if string(b) != j {
t.Fatalf("response body is wrong. got=%q expected=%q", string(b), j)
}
}
func TestWithErrorUnlessStatusCode_UnwrappedError(t *testing.T) {
j := `{
"target": null,
"code": "InternalError",
"message": "Azure is having trouble right now.",
"details": [{"code": "conflict1", "message":"error message1"},
{"code": "conflict2", "message":"error message2"}],
"innererror": []
}`
uuid := "71FDB9F4-5E49-4C12-B266-DE7B4FD999A6"
r := mocks.NewResponseWithContent(j)
mocks.SetResponseHeader(r, HeaderRequestID, uuid)
r.Request = mocks.NewRequest()
r.StatusCode = http.StatusInternalServerError
r.Status = http.StatusText(r.StatusCode)
err := autorest.Respond(r,
WithErrorUnlessStatusCode(http.StatusOK),
autorest.ByClosing())
if err == nil {
t.Fatal("azure: returned nil error for proper error response")
}
azErr, ok := err.(*RequestError)
if !ok {
t.Fatalf("returned error is not azure.RequestError: %T", err)
}
if expected := http.StatusInternalServerError; azErr.StatusCode != expected {
t.Logf("Incorrect StatusCode got: %v want: %d", azErr.StatusCode, expected)
t.Fail()
}
if expected := "Azure is having trouble right now."; azErr.ServiceError.Message != expected {
t.Logf("Incorrect Message\n\tgot: %q\n\twant: %q", azErr.Message, expected)
t.Fail()
}
if expected := uuid; azErr.RequestID != expected {
t.Logf("Incorrect request ID\n\tgot: %q\n\twant: %q", azErr.RequestID, expected)
t.Fail()
}
expectedServiceErrorDetails := `[{"code":"conflict1","message":"error message1"},{"code":"conflict2","message":"error message2"}]`
if azErr.ServiceError == nil {
t.Logf("`ServiceError` was nil when it shouldn't have been.")
t.Fail()
} else if azErr.ServiceError.Details == nil {
t.Logf("`ServiceError.Details` was nil when it should have been %q", expectedServiceErrorDetails)
t.Fail()
} else if details, _ := json.Marshal(*azErr.ServiceError.Details); expectedServiceErrorDetails != string(details) {
t.Logf("Error detaisl was not unmarshaled properly.\n\tgot: %q\n\twant: %q", string(details), expectedServiceErrorDetails)
t.Fail()
}
// the error body should still be there
defer r.Body.Close()
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Error(err)
}
if string(b) != j {
t.Fatalf("response body is wrong. got=%q expected=%q", string(b), j)
}
}
func TestRequestErrorString_WithError(t *testing.T) {
j := `{
"error": {
"code": "InternalError",
"message": "Conflict",
"details": [{"code": "conflict1", "message":"error message1"}]
}
}`
uuid := "71FDB9F4-5E49-4C12-B266-DE7B4FD999A6"
r := mocks.NewResponseWithContent(j)
mocks.SetResponseHeader(r, HeaderRequestID, uuid)
r.Request = mocks.NewRequest()
r.StatusCode = http.StatusInternalServerError
r.Status = http.StatusText(r.StatusCode)
err := autorest.Respond(r,
WithErrorUnlessStatusCode(http.StatusOK),
autorest.ByClosing())
if err == nil {
t.Fatalf("azure: returned nil error for proper error response")
}
azErr, _ := err.(*RequestError)
expected := "autorest/azure: Service returned an error. Status=500 Code=\"InternalError\" Message=\"Conflict\" Details=[{\"code\":\"conflict1\",\"message\":\"error message1\"}]"
if expected != azErr.Error() {
t.Fatalf("azure: send wrong RequestError.\nexpected=%v\ngot=%v", expected, azErr.Error())
}
}
func withErrorPrepareDecorator(e *error) autorest.PrepareDecorator {
return func(p autorest.Preparer) autorest.Preparer {
return autorest.PreparerFunc(func(r *http.Request) (*http.Request, error) {
*e = fmt.Errorf("azure: Faux Prepare Error")
return r, *e
})
}
}
func withAsyncResponseDecorator(n int) autorest.SendDecorator {
i := 0
return func(s autorest.Sender) autorest.Sender {
return autorest.SenderFunc(func(r *http.Request) (*http.Response, error) {
resp, err := s.Do(r)
if err == nil {
if i < n {
resp.StatusCode = http.StatusCreated
resp.Header = http.Header{}
resp.Header.Add(http.CanonicalHeaderKey(headerAsyncOperation), mocks.TestURL)
i++
} else {
resp.StatusCode = http.StatusOK
resp.Header.Del(http.CanonicalHeaderKey(headerAsyncOperation))
}
}
return resp, err
})
}
}
type mockAuthorizer struct{}
func (ma mockAuthorizer) WithAuthorization() autorest.PrepareDecorator {
return autorest.WithHeader(headerAuthorization, mocks.TestAuthorizationHeader)
}
type mockFailingAuthorizer struct{}
func (mfa mockFailingAuthorizer) WithAuthorization() autorest.PrepareDecorator {
return func(p autorest.Preparer) autorest.Preparer {
return autorest.PreparerFunc(func(r *http.Request) (*http.Request, error) {
return r, fmt.Errorf("ERROR: mockFailingAuthorizer returned expected error")
})
}
}
type mockInspector struct {
wasInvoked bool
}
func (mi *mockInspector) WithInspection() autorest.PrepareDecorator {
return func(p autorest.Preparer) autorest.Preparer {
return autorest.PreparerFunc(func(r *http.Request) (*http.Request, error) {
mi.wasInvoked = true
return p.Prepare(r)
})
}
}
func (mi *mockInspector) ByInspecting() autorest.RespondDecorator {
return func(r autorest.Responder) autorest.Responder {
return autorest.ResponderFunc(func(resp *http.Response) error {
mi.wasInvoked = true
return r.Respond(resp)
})
}
}

View File

@@ -1,65 +0,0 @@
package cli
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"github.com/dimchansky/utfbom"
"github.com/mitchellh/go-homedir"
)
// Profile represents a Profile from the Azure CLI
type Profile struct {
InstallationID string `json:"installationId"`
Subscriptions []Subscription `json:"subscriptions"`
}
// Subscription represents a Subscription from the Azure CLI
type Subscription struct {
EnvironmentName string `json:"environmentName"`
ID string `json:"id"`
IsDefault bool `json:"isDefault"`
Name string `json:"name"`
State string `json:"state"`
TenantID string `json:"tenantId"`
}
// ProfilePath returns the path where the Azure Profile is stored from the Azure CLI
func ProfilePath() (string, error) {
return homedir.Expand("~/.azure/azureProfile.json")
}
// LoadProfile restores a Profile object from a file located at 'path'.
func LoadProfile(path string) (result Profile, err error) {
var contents []byte
contents, err = ioutil.ReadFile(path)
if err != nil {
err = fmt.Errorf("failed to open file (%s) while loading token: %v", path, err)
return
}
reader := utfbom.SkipOnly(bytes.NewReader(contents))
dec := json.NewDecoder(reader)
if err = dec.Decode(&result); err != nil {
err = fmt.Errorf("failed to decode contents of file (%s) into a Profile representation: %v", path, err)
return
}
return
}

View File

@@ -1,114 +0,0 @@
package cli
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"encoding/json"
"fmt"
"os"
"strconv"
"time"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/date"
"github.com/mitchellh/go-homedir"
)
// Token represents an AccessToken from the Azure CLI
type Token struct {
AccessToken string `json:"accessToken"`
Authority string `json:"_authority"`
ClientID string `json:"_clientId"`
ExpiresOn string `json:"expiresOn"`
IdentityProvider string `json:"identityProvider"`
IsMRRT bool `json:"isMRRT"`
RefreshToken string `json:"refreshToken"`
Resource string `json:"resource"`
TokenType string `json:"tokenType"`
UserID string `json:"userId"`
}
// ToADALToken converts an Azure CLI `Token`` to an `adal.Token``
func (t Token) ToADALToken() (converted adal.Token, err error) {
tokenExpirationDate, err := ParseExpirationDate(t.ExpiresOn)
if err != nil {
err = fmt.Errorf("Error parsing Token Expiration Date %q: %+v", t.ExpiresOn, err)
return
}
difference := tokenExpirationDate.Sub(date.UnixEpoch())
converted = adal.Token{
AccessToken: t.AccessToken,
Type: t.TokenType,
ExpiresIn: "3600",
ExpiresOn: strconv.Itoa(int(difference.Seconds())),
RefreshToken: t.RefreshToken,
Resource: t.Resource,
}
return
}
// AccessTokensPath returns the path where access tokens are stored from the Azure CLI
// TODO(#199): add unit test.
func AccessTokensPath() (string, error) {
// Azure-CLI allows user to customize the path of access tokens thorugh environment variable.
var accessTokenPath = os.Getenv("AZURE_ACCESS_TOKEN_FILE")
var err error
// Fallback logic to default path on non-cloud-shell environment.
// TODO(#200): remove the dependency on hard-coding path.
if accessTokenPath == "" {
accessTokenPath, err = homedir.Expand("~/.azure/accessTokens.json")
}
return accessTokenPath, err
}
// ParseExpirationDate parses either a Azure CLI or CloudShell date into a time object
func ParseExpirationDate(input string) (*time.Time, error) {
// CloudShell (and potentially the Azure CLI in future)
expirationDate, cloudShellErr := time.Parse(time.RFC3339, input)
if cloudShellErr != nil {
// Azure CLI (Python) e.g. 2017-08-31 19:48:57.998857 (plus the local timezone)
const cliFormat = "2006-01-02 15:04:05.999999"
expirationDate, cliErr := time.ParseInLocation(cliFormat, input, time.Local)
if cliErr == nil {
return &expirationDate, nil
}
return nil, fmt.Errorf("Error parsing expiration date %q.\n\nCloudShell Error: \n%+v\n\nCLI Error:\n%+v", input, cloudShellErr, cliErr)
}
return &expirationDate, nil
}
// LoadTokens restores a set of Token objects from a file located at 'path'.
func LoadTokens(path string) ([]Token, error) {
file, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("failed to open file (%s) while loading token: %v", path, err)
}
defer file.Close()
var tokens []Token
dec := json.NewDecoder(file)
if err = dec.Decode(&tokens); err != nil {
return nil, fmt.Errorf("failed to decode contents of file (%s) into a `cli.Token` representation: %v", path, err)
}
return tokens, nil
}

View File

@@ -1,176 +0,0 @@
package azure
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"encoding/json"
"fmt"
"io/ioutil"
"os"
"strings"
)
// EnvironmentFilepathName captures the name of the environment variable containing the path to the file
// to be used while populating the Azure Environment.
const EnvironmentFilepathName = "AZURE_ENVIRONMENT_FILEPATH"
var environments = map[string]Environment{
"AZURECHINACLOUD": ChinaCloud,
"AZUREGERMANCLOUD": GermanCloud,
"AZUREPUBLICCLOUD": PublicCloud,
"AZUREUSGOVERNMENTCLOUD": USGovernmentCloud,
}
// Environment represents a set of endpoints for each of Azure's Clouds.
type Environment struct {
Name string `json:"name"`
ManagementPortalURL string `json:"managementPortalURL"`
PublishSettingsURL string `json:"publishSettingsURL"`
ServiceManagementEndpoint string `json:"serviceManagementEndpoint"`
ResourceManagerEndpoint string `json:"resourceManagerEndpoint"`
ActiveDirectoryEndpoint string `json:"activeDirectoryEndpoint"`
GalleryEndpoint string `json:"galleryEndpoint"`
KeyVaultEndpoint string `json:"keyVaultEndpoint"`
GraphEndpoint string `json:"graphEndpoint"`
StorageEndpointSuffix string `json:"storageEndpointSuffix"`
SQLDatabaseDNSSuffix string `json:"sqlDatabaseDNSSuffix"`
TrafficManagerDNSSuffix string `json:"trafficManagerDNSSuffix"`
KeyVaultDNSSuffix string `json:"keyVaultDNSSuffix"`
ServiceBusEndpointSuffix string `json:"serviceBusEndpointSuffix"`
ServiceManagementVMDNSSuffix string `json:"serviceManagementVMDNSSuffix"`
ResourceManagerVMDNSSuffix string `json:"resourceManagerVMDNSSuffix"`
ContainerRegistryDNSSuffix string `json:"containerRegistryDNSSuffix"`
}
var (
// PublicCloud is the default public Azure cloud environment
PublicCloud = Environment{
Name: "AzurePublicCloud",
ManagementPortalURL: "https://manage.windowsazure.com/",
PublishSettingsURL: "https://manage.windowsazure.com/publishsettings/index",
ServiceManagementEndpoint: "https://management.core.windows.net/",
ResourceManagerEndpoint: "https://management.azure.com/",
ActiveDirectoryEndpoint: "https://login.microsoftonline.com/",
GalleryEndpoint: "https://gallery.azure.com/",
KeyVaultEndpoint: "https://vault.azure.net/",
GraphEndpoint: "https://graph.windows.net/",
StorageEndpointSuffix: "core.windows.net",
SQLDatabaseDNSSuffix: "database.windows.net",
TrafficManagerDNSSuffix: "trafficmanager.net",
KeyVaultDNSSuffix: "vault.azure.net",
ServiceBusEndpointSuffix: "servicebus.azure.com",
ServiceManagementVMDNSSuffix: "cloudapp.net",
ResourceManagerVMDNSSuffix: "cloudapp.azure.com",
ContainerRegistryDNSSuffix: "azurecr.io",
}
// USGovernmentCloud is the cloud environment for the US Government
USGovernmentCloud = Environment{
Name: "AzureUSGovernmentCloud",
ManagementPortalURL: "https://manage.windowsazure.us/",
PublishSettingsURL: "https://manage.windowsazure.us/publishsettings/index",
ServiceManagementEndpoint: "https://management.core.usgovcloudapi.net/",
ResourceManagerEndpoint: "https://management.usgovcloudapi.net/",
ActiveDirectoryEndpoint: "https://login.microsoftonline.com/",
GalleryEndpoint: "https://gallery.usgovcloudapi.net/",
KeyVaultEndpoint: "https://vault.usgovcloudapi.net/",
GraphEndpoint: "https://graph.usgovcloudapi.net/",
StorageEndpointSuffix: "core.usgovcloudapi.net",
SQLDatabaseDNSSuffix: "database.usgovcloudapi.net",
TrafficManagerDNSSuffix: "usgovtrafficmanager.net",
KeyVaultDNSSuffix: "vault.usgovcloudapi.net",
ServiceBusEndpointSuffix: "servicebus.usgovcloudapi.net",
ServiceManagementVMDNSSuffix: "usgovcloudapp.net",
ResourceManagerVMDNSSuffix: "cloudapp.windowsazure.us",
ContainerRegistryDNSSuffix: "azurecr.io",
}
// ChinaCloud is the cloud environment operated in China
ChinaCloud = Environment{
Name: "AzureChinaCloud",
ManagementPortalURL: "https://manage.chinacloudapi.com/",
PublishSettingsURL: "https://manage.chinacloudapi.com/publishsettings/index",
ServiceManagementEndpoint: "https://management.core.chinacloudapi.cn/",
ResourceManagerEndpoint: "https://management.chinacloudapi.cn/",
ActiveDirectoryEndpoint: "https://login.chinacloudapi.cn/",
GalleryEndpoint: "https://gallery.chinacloudapi.cn/",
KeyVaultEndpoint: "https://vault.azure.cn/",
GraphEndpoint: "https://graph.chinacloudapi.cn/",
StorageEndpointSuffix: "core.chinacloudapi.cn",
SQLDatabaseDNSSuffix: "database.chinacloudapi.cn",
TrafficManagerDNSSuffix: "trafficmanager.cn",
KeyVaultDNSSuffix: "vault.azure.cn",
ServiceBusEndpointSuffix: "servicebus.chinacloudapi.net",
ServiceManagementVMDNSSuffix: "chinacloudapp.cn",
ResourceManagerVMDNSSuffix: "cloudapp.azure.cn",
ContainerRegistryDNSSuffix: "azurecr.io",
}
// GermanCloud is the cloud environment operated in Germany
GermanCloud = Environment{
Name: "AzureGermanCloud",
ManagementPortalURL: "http://portal.microsoftazure.de/",
PublishSettingsURL: "https://manage.microsoftazure.de/publishsettings/index",
ServiceManagementEndpoint: "https://management.core.cloudapi.de/",
ResourceManagerEndpoint: "https://management.microsoftazure.de/",
ActiveDirectoryEndpoint: "https://login.microsoftonline.de/",
GalleryEndpoint: "https://gallery.cloudapi.de/",
KeyVaultEndpoint: "https://vault.microsoftazure.de/",
GraphEndpoint: "https://graph.cloudapi.de/",
StorageEndpointSuffix: "core.cloudapi.de",
SQLDatabaseDNSSuffix: "database.cloudapi.de",
TrafficManagerDNSSuffix: "azuretrafficmanager.de",
KeyVaultDNSSuffix: "vault.microsoftazure.de",
ServiceBusEndpointSuffix: "servicebus.cloudapi.de",
ServiceManagementVMDNSSuffix: "azurecloudapp.de",
ResourceManagerVMDNSSuffix: "cloudapp.microsoftazure.de",
ContainerRegistryDNSSuffix: "azurecr.io",
}
)
// EnvironmentFromName returns an Environment based on the common name specified.
func EnvironmentFromName(name string) (Environment, error) {
// IMPORTANT
// As per @radhikagupta5:
// This is technical debt, fundamentally here because Kubernetes is not currently accepting
// contributions to the providers. Once that is an option, the provider should be updated to
// directly call `EnvironmentFromFile`. Until then, we rely on dispatching Azure Stack environment creation
// from this method based on the name that is provided to us.
if strings.EqualFold(name, "AZURESTACKCLOUD") {
return EnvironmentFromFile(os.Getenv(EnvironmentFilepathName))
}
name = strings.ToUpper(name)
env, ok := environments[name]
if !ok {
return env, fmt.Errorf("autorest/azure: There is no cloud environment matching the name %q", name)
}
return env, nil
}
// EnvironmentFromFile loads an Environment from a configuration file available on disk.
// This function is particularly useful in the Hybrid Cloud model, where one must define their own
// endpoints.
func EnvironmentFromFile(location string) (unmarshaled Environment, err error) {
fileContents, err := ioutil.ReadFile(location)
if err != nil {
return
}
err = json.Unmarshal(fileContents, &unmarshaled)
return
}

View File

@@ -1,284 +0,0 @@
// test
package azure
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"encoding/json"
"os"
"path"
"path/filepath"
"runtime"
"testing"
)
// This correlates to the expected contents of ./testdata/test_environment_1.json
var testEnvironment1 = Environment{
Name: "--unit-test--",
ManagementPortalURL: "--management-portal-url",
PublishSettingsURL: "--publish-settings-url--",
ServiceManagementEndpoint: "--service-management-endpoint--",
ResourceManagerEndpoint: "--resource-management-endpoint--",
ActiveDirectoryEndpoint: "--active-directory-endpoint--",
GalleryEndpoint: "--gallery-endpoint--",
KeyVaultEndpoint: "--key-vault--endpoint--",
GraphEndpoint: "--graph-endpoint--",
StorageEndpointSuffix: "--storage-endpoint-suffix--",
SQLDatabaseDNSSuffix: "--sql-database-dns-suffix--",
TrafficManagerDNSSuffix: "--traffic-manager-dns-suffix--",
KeyVaultDNSSuffix: "--key-vault-dns-suffix--",
ServiceBusEndpointSuffix: "--service-bus-endpoint-suffix--",
ServiceManagementVMDNSSuffix: "--asm-vm-dns-suffix--",
ResourceManagerVMDNSSuffix: "--arm-vm-dns-suffix--",
ContainerRegistryDNSSuffix: "--container-registry-dns-suffix--",
}
func TestEnvironment_EnvironmentFromFile(t *testing.T) {
got, err := EnvironmentFromFile(filepath.Join("testdata", "test_environment_1.json"))
if err != nil {
t.Error(err)
}
if got != testEnvironment1 {
t.Logf("got: %v want: %v", got, testEnvironment1)
t.Fail()
}
}
func TestEnvironment_EnvironmentFromName_Stack(t *testing.T) {
_, currentFile, _, _ := runtime.Caller(0)
prevEnvFilepathValue := os.Getenv(EnvironmentFilepathName)
os.Setenv(EnvironmentFilepathName, filepath.Join(path.Dir(currentFile), "testdata", "test_environment_1.json"))
defer os.Setenv(EnvironmentFilepathName, prevEnvFilepathValue)
got, err := EnvironmentFromName("AZURESTACKCLOUD")
if err != nil {
t.Error(err)
}
if got != testEnvironment1 {
t.Logf("got: %v want: %v", got, testEnvironment1)
t.Fail()
}
}
func TestEnvironmentFromName(t *testing.T) {
name := "azurechinacloud"
if env, _ := EnvironmentFromName(name); env != ChinaCloud {
t.Errorf("Expected to get ChinaCloud for %q", name)
}
name = "AzureChinaCloud"
if env, _ := EnvironmentFromName(name); env != ChinaCloud {
t.Errorf("Expected to get ChinaCloud for %q", name)
}
name = "azuregermancloud"
if env, _ := EnvironmentFromName(name); env != GermanCloud {
t.Errorf("Expected to get GermanCloud for %q", name)
}
name = "AzureGermanCloud"
if env, _ := EnvironmentFromName(name); env != GermanCloud {
t.Errorf("Expected to get GermanCloud for %q", name)
}
name = "azurepubliccloud"
if env, _ := EnvironmentFromName(name); env != PublicCloud {
t.Errorf("Expected to get PublicCloud for %q", name)
}
name = "AzurePublicCloud"
if env, _ := EnvironmentFromName(name); env != PublicCloud {
t.Errorf("Expected to get PublicCloud for %q", name)
}
name = "azureusgovernmentcloud"
if env, _ := EnvironmentFromName(name); env != USGovernmentCloud {
t.Errorf("Expected to get USGovernmentCloud for %q", name)
}
name = "AzureUSGovernmentCloud"
if env, _ := EnvironmentFromName(name); env != USGovernmentCloud {
t.Errorf("Expected to get USGovernmentCloud for %q", name)
}
name = "thisisnotarealcloudenv"
if _, err := EnvironmentFromName(name); err == nil {
t.Errorf("Expected to get an error for %q", name)
}
}
func TestDeserializeEnvironment(t *testing.T) {
env := `{
"name": "--name--",
"ActiveDirectoryEndpoint": "--active-directory-endpoint--",
"galleryEndpoint": "--gallery-endpoint--",
"graphEndpoint": "--graph-endpoint--",
"keyVaultDNSSuffix": "--key-vault-dns-suffix--",
"keyVaultEndpoint": "--key-vault-endpoint--",
"managementPortalURL": "--management-portal-url--",
"publishSettingsURL": "--publish-settings-url--",
"resourceManagerEndpoint": "--resource-manager-endpoint--",
"serviceBusEndpointSuffix": "--service-bus-endpoint-suffix--",
"serviceManagementEndpoint": "--service-management-endpoint--",
"sqlDatabaseDNSSuffix": "--sql-database-dns-suffix--",
"storageEndpointSuffix": "--storage-endpoint-suffix--",
"trafficManagerDNSSuffix": "--traffic-manager-dns-suffix--",
"serviceManagementVMDNSSuffix": "--asm-vm-dns-suffix--",
"resourceManagerVMDNSSuffix": "--arm-vm-dns-suffix--",
"containerRegistryDNSSuffix": "--container-registry-dns-suffix--"
}`
testSubject := Environment{}
err := json.Unmarshal([]byte(env), &testSubject)
if err != nil {
t.Fatalf("failed to unmarshal: %s", err)
}
if "--name--" != testSubject.Name {
t.Errorf("Expected Name to be \"--name--\", but got %q", testSubject.Name)
}
if "--management-portal-url--" != testSubject.ManagementPortalURL {
t.Errorf("Expected ManagementPortalURL to be \"--management-portal-url--\", but got %q", testSubject.ManagementPortalURL)
}
if "--publish-settings-url--" != testSubject.PublishSettingsURL {
t.Errorf("Expected PublishSettingsURL to be \"--publish-settings-url--\", but got %q", testSubject.PublishSettingsURL)
}
if "--service-management-endpoint--" != testSubject.ServiceManagementEndpoint {
t.Errorf("Expected ServiceManagementEndpoint to be \"--service-management-endpoint--\", but got %q", testSubject.ServiceManagementEndpoint)
}
if "--resource-manager-endpoint--" != testSubject.ResourceManagerEndpoint {
t.Errorf("Expected ResourceManagerEndpoint to be \"--resource-manager-endpoint--\", but got %q", testSubject.ResourceManagerEndpoint)
}
if "--active-directory-endpoint--" != testSubject.ActiveDirectoryEndpoint {
t.Errorf("Expected ActiveDirectoryEndpoint to be \"--active-directory-endpoint--\", but got %q", testSubject.ActiveDirectoryEndpoint)
}
if "--gallery-endpoint--" != testSubject.GalleryEndpoint {
t.Errorf("Expected GalleryEndpoint to be \"--gallery-endpoint--\", but got %q", testSubject.GalleryEndpoint)
}
if "--key-vault-endpoint--" != testSubject.KeyVaultEndpoint {
t.Errorf("Expected KeyVaultEndpoint to be \"--key-vault-endpoint--\", but got %q", testSubject.KeyVaultEndpoint)
}
if "--graph-endpoint--" != testSubject.GraphEndpoint {
t.Errorf("Expected GraphEndpoint to be \"--graph-endpoint--\", but got %q", testSubject.GraphEndpoint)
}
if "--storage-endpoint-suffix--" != testSubject.StorageEndpointSuffix {
t.Errorf("Expected StorageEndpointSuffix to be \"--storage-endpoint-suffix--\", but got %q", testSubject.StorageEndpointSuffix)
}
if "--sql-database-dns-suffix--" != testSubject.SQLDatabaseDNSSuffix {
t.Errorf("Expected sql-database-dns-suffix to be \"--sql-database-dns-suffix--\", but got %q", testSubject.SQLDatabaseDNSSuffix)
}
if "--key-vault-dns-suffix--" != testSubject.KeyVaultDNSSuffix {
t.Errorf("Expected StorageEndpointSuffix to be \"--key-vault-dns-suffix--\", but got %q", testSubject.KeyVaultDNSSuffix)
}
if "--service-bus-endpoint-suffix--" != testSubject.ServiceBusEndpointSuffix {
t.Errorf("Expected StorageEndpointSuffix to be \"--service-bus-endpoint-suffix--\", but got %q", testSubject.ServiceBusEndpointSuffix)
}
if "--asm-vm-dns-suffix--" != testSubject.ServiceManagementVMDNSSuffix {
t.Errorf("Expected ServiceManagementVMDNSSuffix to be \"--asm-vm-dns-suffix--\", but got %q", testSubject.ServiceManagementVMDNSSuffix)
}
if "--arm-vm-dns-suffix--" != testSubject.ResourceManagerVMDNSSuffix {
t.Errorf("Expected ResourceManagerVMDNSSuffix to be \"--arm-vm-dns-suffix--\", but got %q", testSubject.ResourceManagerVMDNSSuffix)
}
if "--container-registry-dns-suffix--" != testSubject.ContainerRegistryDNSSuffix {
t.Errorf("Expected ContainerRegistryDNSSuffix to be \"--container-registry-dns-suffix--\", but got %q", testSubject.ContainerRegistryDNSSuffix)
}
}
func TestRoundTripSerialization(t *testing.T) {
env := Environment{
Name: "--unit-test--",
ManagementPortalURL: "--management-portal-url",
PublishSettingsURL: "--publish-settings-url--",
ServiceManagementEndpoint: "--service-management-endpoint--",
ResourceManagerEndpoint: "--resource-management-endpoint--",
ActiveDirectoryEndpoint: "--active-directory-endpoint--",
GalleryEndpoint: "--gallery-endpoint--",
KeyVaultEndpoint: "--key-vault--endpoint--",
GraphEndpoint: "--graph-endpoint--",
StorageEndpointSuffix: "--storage-endpoint-suffix--",
SQLDatabaseDNSSuffix: "--sql-database-dns-suffix--",
TrafficManagerDNSSuffix: "--traffic-manager-dns-suffix--",
KeyVaultDNSSuffix: "--key-vault-dns-suffix--",
ServiceBusEndpointSuffix: "--service-bus-endpoint-suffix--",
ServiceManagementVMDNSSuffix: "--asm-vm-dns-suffix--",
ResourceManagerVMDNSSuffix: "--arm-vm-dns-suffix--",
ContainerRegistryDNSSuffix: "--container-registry-dns-suffix--",
}
bytes, err := json.Marshal(env)
if err != nil {
t.Fatalf("failed to marshal: %s", err)
}
testSubject := Environment{}
err = json.Unmarshal(bytes, &testSubject)
if err != nil {
t.Fatalf("failed to unmarshal: %s", err)
}
if env.Name != testSubject.Name {
t.Errorf("Expected Name to be %q, but got %q", env.Name, testSubject.Name)
}
if env.ManagementPortalURL != testSubject.ManagementPortalURL {
t.Errorf("Expected ManagementPortalURL to be %q, but got %q", env.ManagementPortalURL, testSubject.ManagementPortalURL)
}
if env.PublishSettingsURL != testSubject.PublishSettingsURL {
t.Errorf("Expected PublishSettingsURL to be %q, but got %q", env.PublishSettingsURL, testSubject.PublishSettingsURL)
}
if env.ServiceManagementEndpoint != testSubject.ServiceManagementEndpoint {
t.Errorf("Expected ServiceManagementEndpoint to be %q, but got %q", env.ServiceManagementEndpoint, testSubject.ServiceManagementEndpoint)
}
if env.ResourceManagerEndpoint != testSubject.ResourceManagerEndpoint {
t.Errorf("Expected ResourceManagerEndpoint to be %q, but got %q", env.ResourceManagerEndpoint, testSubject.ResourceManagerEndpoint)
}
if env.ActiveDirectoryEndpoint != testSubject.ActiveDirectoryEndpoint {
t.Errorf("Expected ActiveDirectoryEndpoint to be %q, but got %q", env.ActiveDirectoryEndpoint, testSubject.ActiveDirectoryEndpoint)
}
if env.GalleryEndpoint != testSubject.GalleryEndpoint {
t.Errorf("Expected GalleryEndpoint to be %q, but got %q", env.GalleryEndpoint, testSubject.GalleryEndpoint)
}
if env.KeyVaultEndpoint != testSubject.KeyVaultEndpoint {
t.Errorf("Expected KeyVaultEndpoint to be %q, but got %q", env.KeyVaultEndpoint, testSubject.KeyVaultEndpoint)
}
if env.GraphEndpoint != testSubject.GraphEndpoint {
t.Errorf("Expected GraphEndpoint to be %q, but got %q", env.GraphEndpoint, testSubject.GraphEndpoint)
}
if env.StorageEndpointSuffix != testSubject.StorageEndpointSuffix {
t.Errorf("Expected StorageEndpointSuffix to be %q, but got %q", env.StorageEndpointSuffix, testSubject.StorageEndpointSuffix)
}
if env.SQLDatabaseDNSSuffix != testSubject.SQLDatabaseDNSSuffix {
t.Errorf("Expected SQLDatabaseDNSSuffix to be %q, but got %q", env.SQLDatabaseDNSSuffix, testSubject.SQLDatabaseDNSSuffix)
}
if env.TrafficManagerDNSSuffix != testSubject.TrafficManagerDNSSuffix {
t.Errorf("Expected TrafficManagerDNSSuffix to be %q, but got %q", env.TrafficManagerDNSSuffix, testSubject.TrafficManagerDNSSuffix)
}
if env.KeyVaultDNSSuffix != testSubject.KeyVaultDNSSuffix {
t.Errorf("Expected KeyVaultDNSSuffix to be %q, but got %q", env.KeyVaultDNSSuffix, testSubject.KeyVaultDNSSuffix)
}
if env.ServiceBusEndpointSuffix != testSubject.ServiceBusEndpointSuffix {
t.Errorf("Expected ServiceBusEndpointSuffix to be %q, but got %q", env.ServiceBusEndpointSuffix, testSubject.ServiceBusEndpointSuffix)
}
if env.ServiceManagementVMDNSSuffix != testSubject.ServiceManagementVMDNSSuffix {
t.Errorf("Expected ServiceManagementVMDNSSuffix to be %q, but got %q", env.ServiceManagementVMDNSSuffix, testSubject.ServiceManagementVMDNSSuffix)
}
if env.ResourceManagerVMDNSSuffix != testSubject.ResourceManagerVMDNSSuffix {
t.Errorf("Expected ResourceManagerVMDNSSuffix to be %q, but got %q", env.ResourceManagerVMDNSSuffix, testSubject.ResourceManagerVMDNSSuffix)
}
if env.ContainerRegistryDNSSuffix != testSubject.ContainerRegistryDNSSuffix {
t.Errorf("Expected ContainerRegistryDNSSuffix to be %q, but got %q", env.ContainerRegistryDNSSuffix, testSubject.ContainerRegistryDNSSuffix)
}
}

View File

@@ -1,127 +0,0 @@
# autorest azure example
## Usage (device mode)
This shows how to use the example for device auth.
1. Execute this. It will save your token to /tmp/azure-example-token:
```
./example -tenantId "13de0a15-b5db-44b9-b682-b4ba82afbd29" -subscriptionId "aff271ee-e9be-4441-b9bb-42f5af4cbaeb" -mode "device" -tokenCachePath "/tmp/azure-example-token"
```
2. Execute it again, it will load the token from cache and not prompt for auth again.
## Usage (certificate mode)
This example covers how to make an authenticated call to the Azure Resource Manager APIs, using certificate-based authentication.
0. Export some required variables
```
export SUBSCRIPTION_ID="aff271ee-e9be-4441-b9bb-42f5af4cbaeb"
export TENANT_ID="13de0a15-b5db-44b9-b682-b4ba82afbd29"
export RESOURCE_GROUP="someresourcegroup"
```
* replace both values with your own
1. Create a private key
```
openssl genrsa -out "example.key" 2048
```
2. Create the certificate
```
openssl req -new -key "example.key" -subj "/CN=example" -out "example.csr"
openssl x509 -req -in "example.csr" -signkey "example.key" -out "example.crt" -days 10000
```
3. Create the PKCS12 version of the certificate (with no password)
```
openssl pkcs12 -export -out "example.pfx" -inkey "example.key" -in "example.crt" -passout pass:
```
4. Register a new Azure AD Application with the certificate contents
```
certificateContents="$(tail -n+2 "example.key" | head -n-1)"
azure ad app create \
--name "example-azuread-app" \
--home-page="http://example-azuread-app/home" \
--identifier-uris "http://example-azuread-app/app" \
--key-usage "Verify" \
--end-date "2020-01-01" \
--key-value "${certificateContents}"
```
5. Create a new service principal using the "Application Id" from the previous step
```
azure ad sp create "APPLICATION_ID"
```
* Replace APPLICATION_ID with the "Application Id" returned in step 4
6. Grant your service principal necessary permissions
```
azure role assignment create \
--resource-group "${RESOURCE_GROUP}" \
--roleName "Contributor" \
--subscription "${SUBSCRIPTION_ID}" \
--spn "http://example-azuread-app/app"
```
* Replace SUBSCRIPTION_ID with your subscription id
* Replace RESOURCE_GROUP with the resource group for the assignment
* Ensure that the `spn` parameter matches an `identifier-url` from Step 4
7. Run this example app to see your resource groups
```
go run main.go \
--tenantId="${TENANT_ID}" \
--subscriptionId="${SUBSCRIPTION_ID}" \
--applicationId="http://example-azuread-app/app" \
--certificatePath="certificate.pfx"
```
You should see something like this as output:
```
2015/11/08 18:28:39 Using these settings:
2015/11/08 18:28:39 * certificatePath: certificate.pfx
2015/11/08 18:28:39 * applicationID: http://example-azuread-app/app
2015/11/08 18:28:39 * tenantID: 13de0a15-b5db-44b9-b682-b4ba82afbd29
2015/11/08 18:28:39 * subscriptionID: aff271ee-e9be-4441-b9bb-42f5af4cbaeb
2015/11/08 18:28:39 loading certificate...
2015/11/08 18:28:39 retrieve oauth token...
2015/11/08 18:28:39 querying the list of resource groups...
2015/11/08 18:28:50
2015/11/08 18:28:50 Groups: {"value":[{"id":"/subscriptions/aff271ee-e9be-4441-b9bb-42f5af4cbaeb/resourceGroups/kube-66f30810","name":"kube-66f30810","location":"westus","tags":{},"properties":{"provisioningState":"Succeeded"}}]}
```
## Notes
You may need to wait sometime between executing step 4, step 5 and step 6. If you issue those requests too quickly, you might hit an AD server that is not consistent with the server where the resource was created.

View File

@@ -1,272 +0,0 @@
package main
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"crypto/rsa"
"crypto/x509"
"encoding/json"
"flag"
"fmt"
"io/ioutil"
"log"
"net/http"
"strings"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/azure"
"golang.org/x/crypto/pkcs12"
)
const (
resourceGroupURLTemplate = "https://management.azure.com"
apiVersion = "2015-01-01"
nativeAppClientID = "a87032a7-203c-4bf7-913c-44c50d23409a"
resource = "https://management.core.windows.net/"
)
var (
mode string
tenantID string
subscriptionID string
applicationID string
tokenCachePath string
forceRefresh bool
impatient bool
certificatePath string
)
func init() {
flag.StringVar(&mode, "mode", "device", "mode of operation for SPT creation")
flag.StringVar(&certificatePath, "certificatePath", "", "path to pk12/pfx certificate")
flag.StringVar(&applicationID, "applicationId", "", "application id")
flag.StringVar(&tenantID, "tenantId", "", "tenant id")
flag.StringVar(&subscriptionID, "subscriptionId", "", "subscription id")
flag.StringVar(&tokenCachePath, "tokenCachePath", "", "location of oauth token cache")
flag.BoolVar(&forceRefresh, "forceRefresh", false, "pass true to force a token refresh")
flag.Parse()
log.Printf("mode(%s) certPath(%s) appID(%s) tenantID(%s), subID(%s)\n",
mode, certificatePath, applicationID, tenantID, subscriptionID)
if mode == "certificate" &&
(strings.TrimSpace(tenantID) == "" || strings.TrimSpace(subscriptionID) == "") {
log.Fatalln("Bad usage. Using certificate mode. Please specify tenantID, subscriptionID")
}
if mode != "certificate" && mode != "device" {
log.Fatalln("Bad usage. Mode must be one of 'certificate' or 'device'.")
}
if mode == "device" && strings.TrimSpace(applicationID) == "" {
log.Println("Using device mode auth. Will use `azkube` clientID since none was specified on the comand line.")
applicationID = nativeAppClientID
}
if mode == "certificate" && strings.TrimSpace(certificatePath) == "" {
log.Fatalln("Bad usage. Mode 'certificate' requires the 'certificatePath' argument.")
}
if strings.TrimSpace(tenantID) == "" || strings.TrimSpace(subscriptionID) == "" || strings.TrimSpace(applicationID) == "" {
log.Fatalln("Bad usage. Must specify the 'tenantId' and 'subscriptionId'")
}
}
func getSptFromCachedToken(oauthConfig adal.OAuthConfig, clientID, resource string, callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
token, err := adal.LoadToken(tokenCachePath)
if err != nil {
return nil, fmt.Errorf("failed to load token from cache: %v", err)
}
spt, _ := adal.NewServicePrincipalTokenFromManualToken(
oauthConfig,
clientID,
resource,
*token,
callbacks...)
return spt, nil
}
func decodePkcs12(pkcs []byte, password string) (*x509.Certificate, *rsa.PrivateKey, error) {
privateKey, certificate, err := pkcs12.Decode(pkcs, password)
if err != nil {
return nil, nil, err
}
rsaPrivateKey, isRsaKey := privateKey.(*rsa.PrivateKey)
if !isRsaKey {
return nil, nil, fmt.Errorf("PKCS#12 certificate must contain an RSA private key")
}
return certificate, rsaPrivateKey, nil
}
func getSptFromCertificate(oauthConfig adal.OAuthConfig, clientID, resource, certicatePath string, callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
certData, err := ioutil.ReadFile(certificatePath)
if err != nil {
return nil, fmt.Errorf("failed to read the certificate file (%s): %v", certificatePath, err)
}
certificate, rsaPrivateKey, err := decodePkcs12(certData, "")
if err != nil {
return nil, fmt.Errorf("failed to decode pkcs12 certificate while creating spt: %v", err)
}
spt, _ := adal.NewServicePrincipalTokenFromCertificate(
oauthConfig,
clientID,
certificate,
rsaPrivateKey,
resource,
callbacks...)
return spt, nil
}
func getSptFromDeviceFlow(oauthConfig adal.OAuthConfig, clientID, resource string, callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
oauthClient := &autorest.Client{}
deviceCode, err := adal.InitiateDeviceAuth(oauthClient, oauthConfig, clientID, resource)
if err != nil {
return nil, fmt.Errorf("failed to start device auth flow: %s", err)
}
fmt.Println(*deviceCode.Message)
token, err := adal.WaitForUserCompletion(oauthClient, deviceCode)
if err != nil {
return nil, fmt.Errorf("failed to finish device auth flow: %s", err)
}
spt, err := adal.NewServicePrincipalTokenFromManualToken(
oauthConfig,
clientID,
resource,
*token,
callbacks...)
if err != nil {
return nil, fmt.Errorf("failed to get oauth token from device flow: %v", err)
}
return spt, nil
}
func printResourceGroups(client *autorest.Client) error {
p := map[string]interface{}{"subscription-id": subscriptionID}
q := map[string]interface{}{"api-version": apiVersion}
req, _ := autorest.Prepare(&http.Request{},
autorest.AsGet(),
autorest.WithBaseURL(resourceGroupURLTemplate),
autorest.WithPathParameters("/subscriptions/{subscription-id}/resourcegroups", p),
autorest.WithQueryParameters(q))
resp, err := autorest.SendWithSender(client, req)
if err != nil {
return err
}
value := struct {
ResourceGroups []struct {
Name string `json:"name"`
} `json:"value"`
}{}
defer resp.Body.Close()
dec := json.NewDecoder(resp.Body)
err = dec.Decode(&value)
if err != nil {
return err
}
var groupNames = make([]string, len(value.ResourceGroups))
for i, name := range value.ResourceGroups {
groupNames[i] = name.Name
}
log.Println("Groups:", strings.Join(groupNames, ", "))
return err
}
func saveToken(spt adal.Token) {
if tokenCachePath != "" {
err := adal.SaveToken(tokenCachePath, 0600, spt)
if err != nil {
log.Println("error saving token", err)
} else {
log.Println("saved token to", tokenCachePath)
}
}
}
func main() {
var spt *adal.ServicePrincipalToken
var err error
callback := func(t adal.Token) error {
log.Println("refresh callback was called")
saveToken(t)
return nil
}
oauthConfig, err := adal.NewOAuthConfig(azure.PublicCloud.ActiveDirectoryEndpoint, tenantID)
if err != nil {
panic(err)
}
if tokenCachePath != "" {
log.Println("tokenCachePath specified; attempting to load from", tokenCachePath)
spt, err = getSptFromCachedToken(*oauthConfig, applicationID, resource, callback)
if err != nil {
spt = nil // just in case, this is the condition below
log.Println("loading from cache failed:", err)
}
}
if spt == nil {
log.Println("authenticating via 'mode'", mode)
switch mode {
case "device":
spt, err = getSptFromDeviceFlow(*oauthConfig, applicationID, resource, callback)
case "certificate":
spt, err = getSptFromCertificate(*oauthConfig, applicationID, resource, certificatePath, callback)
}
if err != nil {
log.Fatalln("failed to retrieve token:", err)
}
// should save it as soon as you get it since Refresh won't be called for some time
if tokenCachePath != "" {
saveToken(spt.Token)
}
}
client := &autorest.Client{}
client.Authorizer = autorest.NewBearerAuthorizer(spt)
printResourceGroups(client)
if forceRefresh {
err = spt.Refresh()
if err != nil {
panic(err)
}
printResourceGroups(client)
}
}

View File

@@ -1,203 +0,0 @@
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package azure
import (
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/Azure/go-autorest/autorest"
)
// DoRetryWithRegistration tries to register the resource provider in case it is unregistered.
// It also handles request retries
func DoRetryWithRegistration(client autorest.Client) autorest.SendDecorator {
return func(s autorest.Sender) autorest.Sender {
return autorest.SenderFunc(func(r *http.Request) (resp *http.Response, err error) {
rr := autorest.NewRetriableRequest(r)
for currentAttempt := 0; currentAttempt < client.RetryAttempts; currentAttempt++ {
err = rr.Prepare()
if err != nil {
return resp, err
}
resp, err = autorest.SendWithSender(s, rr.Request(),
autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...),
)
if err != nil {
return resp, err
}
if resp.StatusCode != http.StatusConflict {
return resp, err
}
var re RequestError
err = autorest.Respond(
resp,
autorest.ByUnmarshallingJSON(&re),
)
if err != nil {
return resp, err
}
err = re
if re.ServiceError != nil && re.ServiceError.Code == "MissingSubscriptionRegistration" {
regErr := register(client, r, re)
if regErr != nil {
return resp, fmt.Errorf("failed auto registering Resource Provider: %s. Original error: %s", regErr, err)
}
}
}
return resp, fmt.Errorf("failed request: %s", err)
})
}
}
func getProvider(re RequestError) (string, error) {
if re.ServiceError != nil {
if re.ServiceError.Details != nil && len(*re.ServiceError.Details) > 0 {
detail := (*re.ServiceError.Details)[0].(map[string]interface{})
return detail["target"].(string), nil
}
}
return "", errors.New("provider was not found in the response")
}
func register(client autorest.Client, originalReq *http.Request, re RequestError) error {
subID := getSubscription(originalReq.URL.Path)
if subID == "" {
return errors.New("missing parameter subscriptionID to register resource provider")
}
providerName, err := getProvider(re)
if err != nil {
return fmt.Errorf("missing parameter provider to register resource provider: %s", err)
}
newURL := url.URL{
Scheme: originalReq.URL.Scheme,
Host: originalReq.URL.Host,
}
// taken from the resources SDK
// with almost identical code, this sections are easier to mantain
// It is also not a good idea to import the SDK here
// https://github.com/Azure/azure-sdk-for-go/blob/9f366792afa3e0ddaecdc860e793ba9d75e76c27/arm/resources/resources/providers.go#L252
pathParameters := map[string]interface{}{
"resourceProviderNamespace": autorest.Encode("path", providerName),
"subscriptionId": autorest.Encode("path", subID),
}
const APIVersion = "2016-09-01"
queryParameters := map[string]interface{}{
"api-version": APIVersion,
}
preparer := autorest.CreatePreparer(
autorest.AsPost(),
autorest.WithBaseURL(newURL.String()),
autorest.WithPathParameters("/subscriptions/{subscriptionId}/providers/{resourceProviderNamespace}/register", pathParameters),
autorest.WithQueryParameters(queryParameters),
)
req, err := preparer.Prepare(&http.Request{})
if err != nil {
return err
}
req.Cancel = originalReq.Cancel
resp, err := autorest.SendWithSender(client, req,
autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...),
)
if err != nil {
return err
}
type Provider struct {
RegistrationState *string `json:"registrationState,omitempty"`
}
var provider Provider
err = autorest.Respond(
resp,
WithErrorUnlessStatusCode(http.StatusOK),
autorest.ByUnmarshallingJSON(&provider),
autorest.ByClosing(),
)
if err != nil {
return err
}
// poll for registered provisioning state
now := time.Now()
for err == nil && time.Since(now) < client.PollingDuration {
// taken from the resources SDK
// https://github.com/Azure/azure-sdk-for-go/blob/9f366792afa3e0ddaecdc860e793ba9d75e76c27/arm/resources/resources/providers.go#L45
preparer := autorest.CreatePreparer(
autorest.AsGet(),
autorest.WithBaseURL(newURL.String()),
autorest.WithPathParameters("/subscriptions/{subscriptionId}/providers/{resourceProviderNamespace}", pathParameters),
autorest.WithQueryParameters(queryParameters),
)
req, err = preparer.Prepare(&http.Request{})
if err != nil {
return err
}
req.Cancel = originalReq.Cancel
resp, err := autorest.SendWithSender(client.Sender, req,
autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...),
)
if err != nil {
return err
}
err = autorest.Respond(
resp,
WithErrorUnlessStatusCode(http.StatusOK),
autorest.ByUnmarshallingJSON(&provider),
autorest.ByClosing(),
)
if err != nil {
return err
}
if provider.RegistrationState != nil &&
*provider.RegistrationState == "Registered" {
break
}
delayed := autorest.DelayWithRetryAfter(resp, originalReq.Cancel)
if !delayed {
autorest.DelayForBackoff(client.PollingDelay, 0, originalReq.Cancel)
}
}
if !(time.Since(now) < client.PollingDuration) {
return errors.New("polling for resource provider registration has exceeded the polling duration")
}
return err
}
func getSubscription(path string) string {
parts := strings.Split(path, "/")
for i, v := range parts {
if v == "subscriptions" && (i+1) < len(parts) {
return parts[i+1]
}
}
return ""
}

View File

@@ -1,81 +0,0 @@
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package azure
import (
"net/http"
"testing"
"time"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/mocks"
)
func TestDoRetryWithRegistration(t *testing.T) {
client := mocks.NewSender()
// first response, should retry because it is a transient error
client.AppendResponse(mocks.NewResponseWithStatus("Internal server error", http.StatusInternalServerError))
// response indicates the resource provider has not been registered
client.AppendResponse(mocks.NewResponseWithBodyAndStatus(mocks.NewBody(`{
"error":{
"code":"MissingSubscriptionRegistration",
"message":"The subscription registration is in 'Unregistered' state. The subscription must be registered to use namespace 'Microsoft.EventGrid'. See https://aka.ms/rps-not-found for how to register subscriptions.",
"details":[
{
"code":"MissingSubscriptionRegistration",
"target":"Microsoft.EventGrid",
"message":"The subscription registration is in 'Unregistered' state. The subscription must be registered to use namespace 'Microsoft.EventGrid'. See https://aka.ms/rps-not-found for how to register subscriptions."
}
]
}
}
`), http.StatusConflict, "MissingSubscriptionRegistration"))
// first poll response, still not ready
client.AppendResponse(mocks.NewResponseWithBodyAndStatus(mocks.NewBody(`{
"registrationState": "Registering"
}
`), http.StatusOK, "200 OK"))
// last poll response, respurce provider has been registered
client.AppendResponse(mocks.NewResponseWithBodyAndStatus(mocks.NewBody(`{
"registrationState": "Registered"
}
`), http.StatusOK, "200 OK"))
// retry original request, response is successful
client.AppendResponse(mocks.NewResponseWithStatus("200 OK", http.StatusOK))
req := mocks.NewRequestForURL("https://lol/subscriptions/rofl")
req.Body = mocks.NewBody("lolol")
r, err := autorest.SendWithSender(client, req,
DoRetryWithRegistration(autorest.Client{
PollingDelay: time.Second,
PollingDuration: time.Second * 10,
RetryAttempts: 5,
RetryDuration: time.Second,
Sender: client,
}),
)
if err != nil {
t.Fatalf("got error: %v", err)
}
autorest.Respond(r,
autorest.ByDiscardingBody(),
autorest.ByClosing(),
)
if r.StatusCode != http.StatusOK {
t.Fatalf("azure: Sender#DoRetryWithRegistration -- Got: StatusCode %v; Want: StatusCode 200 OK", r.StatusCode)
}
}

View File

@@ -1,19 +0,0 @@
{
"name": "--unit-test--",
"managementPortalURL": "--management-portal-url",
"publishSettingsURL": "--publish-settings-url--",
"serviceManagementEndpoint": "--service-management-endpoint--",
"resourceManagerEndpoint": "--resource-management-endpoint--",
"activeDirectoryEndpoint": "--active-directory-endpoint--",
"galleryEndpoint": "--gallery-endpoint--",
"keyVaultEndpoint": "--key-vault--endpoint--",
"graphEndpoint": "--graph-endpoint--",
"storageEndpointSuffix": "--storage-endpoint-suffix--",
"sqlDatabaseDNSSuffix": "--sql-database-dns-suffix--",
"trafficManagerDNSSuffix": "--traffic-manager-dns-suffix--",
"keyVaultDNSSuffix": "--key-vault-dns-suffix--",
"serviceBusEndpointSuffix": "--service-bus-endpoint-suffix--",
"serviceManagementVMDNSSuffix": "--asm-vm-dns-suffix--",
"resourceManagerVMDNSSuffix": "--arm-vm-dns-suffix--",
"containerRegistryDNSSuffix": "--container-registry-dns-suffix--"
}

View File

@@ -1,254 +0,0 @@
package autorest
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"log"
"net/http"
"net/http/cookiejar"
"runtime"
"time"
)
const (
// DefaultPollingDelay is a reasonable delay between polling requests.
DefaultPollingDelay = 60 * time.Second
// DefaultPollingDuration is a reasonable total polling duration.
DefaultPollingDuration = 15 * time.Minute
// DefaultRetryAttempts is number of attempts for retry status codes (5xx).
DefaultRetryAttempts = 3
// DefaultRetryDuration is the duration to wait between retries.
DefaultRetryDuration = 30 * time.Second
)
var (
// defaultUserAgent builds a string containing the Go version, system archityecture and OS,
// and the go-autorest version.
defaultUserAgent = fmt.Sprintf("Go/%s (%s-%s) go-autorest/%s",
runtime.Version(),
runtime.GOARCH,
runtime.GOOS,
Version(),
)
// StatusCodesForRetry are a defined group of status code for which the client will retry
StatusCodesForRetry = []int{
http.StatusRequestTimeout, // 408
http.StatusTooManyRequests, // 429
http.StatusInternalServerError, // 500
http.StatusBadGateway, // 502
http.StatusServiceUnavailable, // 503
http.StatusGatewayTimeout, // 504
}
)
const (
requestFormat = `HTTP Request Begin ===================================================
%s
===================================================== HTTP Request End
`
responseFormat = `HTTP Response Begin ===================================================
%s
===================================================== HTTP Response End
`
)
// Response serves as the base for all responses from generated clients. It provides access to the
// last http.Response.
type Response struct {
*http.Response `json:"-"`
}
// LoggingInspector implements request and response inspectors that log the full request and
// response to a supplied log.
type LoggingInspector struct {
Logger *log.Logger
}
// WithInspection returns a PrepareDecorator that emits the http.Request to the supplied logger. The
// body is restored after being emitted.
//
// Note: Since it reads the entire Body, this decorator should not be used where body streaming is
// important. It is best used to trace JSON or similar body values.
func (li LoggingInspector) WithInspection() PrepareDecorator {
return func(p Preparer) Preparer {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
var body, b bytes.Buffer
defer r.Body.Close()
r.Body = ioutil.NopCloser(io.TeeReader(r.Body, &body))
if err := r.Write(&b); err != nil {
return nil, fmt.Errorf("Failed to write response: %v", err)
}
li.Logger.Printf(requestFormat, b.String())
r.Body = ioutil.NopCloser(&body)
return p.Prepare(r)
})
}
}
// ByInspecting returns a RespondDecorator that emits the http.Response to the supplied logger. The
// body is restored after being emitted.
//
// Note: Since it reads the entire Body, this decorator should not be used where body streaming is
// important. It is best used to trace JSON or similar body values.
func (li LoggingInspector) ByInspecting() RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
var body, b bytes.Buffer
defer resp.Body.Close()
resp.Body = ioutil.NopCloser(io.TeeReader(resp.Body, &body))
if err := resp.Write(&b); err != nil {
return fmt.Errorf("Failed to write response: %v", err)
}
li.Logger.Printf(responseFormat, b.String())
resp.Body = ioutil.NopCloser(&body)
return r.Respond(resp)
})
}
}
// Client is the base for autorest generated clients. It provides default, "do nothing"
// implementations of an Authorizer, RequestInspector, and ResponseInspector. It also returns the
// standard, undecorated http.Client as a default Sender.
//
// Generated clients should also use Error (see NewError and NewErrorWithError) for errors and
// return responses that compose with Response.
//
// Most customization of generated clients is best achieved by supplying a custom Authorizer, custom
// RequestInspector, and / or custom ResponseInspector. Users may log requests, implement circuit
// breakers (see https://msdn.microsoft.com/en-us/library/dn589784.aspx) or otherwise influence
// sending the request by providing a decorated Sender.
type Client struct {
Authorizer Authorizer
Sender Sender
RequestInspector PrepareDecorator
ResponseInspector RespondDecorator
// PollingDelay sets the polling frequency used in absence of a Retry-After HTTP header
PollingDelay time.Duration
// PollingDuration sets the maximum polling time after which an error is returned.
PollingDuration time.Duration
// RetryAttempts sets the default number of retry attempts for client.
RetryAttempts int
// RetryDuration sets the delay duration for retries.
RetryDuration time.Duration
// UserAgent, if not empty, will be set as the HTTP User-Agent header on all requests sent
// through the Do method.
UserAgent string
Jar http.CookieJar
}
// NewClientWithUserAgent returns an instance of a Client with the UserAgent set to the passed
// string.
func NewClientWithUserAgent(ua string) Client {
c := Client{
PollingDelay: DefaultPollingDelay,
PollingDuration: DefaultPollingDuration,
RetryAttempts: DefaultRetryAttempts,
RetryDuration: DefaultRetryDuration,
UserAgent: defaultUserAgent,
}
c.Sender = c.sender()
c.AddToUserAgent(ua)
return c
}
// AddToUserAgent adds an extension to the current user agent
func (c *Client) AddToUserAgent(extension string) error {
if extension != "" {
c.UserAgent = fmt.Sprintf("%s %s", c.UserAgent, extension)
return nil
}
return fmt.Errorf("Extension was empty, User Agent stayed as %s", c.UserAgent)
}
// Do implements the Sender interface by invoking the active Sender after applying authorization.
// If Sender is not set, it uses a new instance of http.Client. In both cases it will, if UserAgent
// is set, apply set the User-Agent header.
func (c Client) Do(r *http.Request) (*http.Response, error) {
if r.UserAgent() == "" {
r, _ = Prepare(r,
WithUserAgent(c.UserAgent))
}
r, err := Prepare(r,
c.WithInspection(),
c.WithAuthorization())
if err != nil {
return nil, NewErrorWithError(err, "autorest/Client", "Do", nil, "Preparing request failed")
}
resp, err := SendWithSender(c.sender(), r)
Respond(resp, c.ByInspecting())
return resp, err
}
// sender returns the Sender to which to send requests.
func (c Client) sender() Sender {
if c.Sender == nil {
j, _ := cookiejar.New(nil)
return &http.Client{Jar: j}
}
return c.Sender
}
// WithAuthorization is a convenience method that returns the WithAuthorization PrepareDecorator
// from the current Authorizer. If not Authorizer is set, it uses the NullAuthorizer.
func (c Client) WithAuthorization() PrepareDecorator {
return c.authorizer().WithAuthorization()
}
// authorizer returns the Authorizer to use.
func (c Client) authorizer() Authorizer {
if c.Authorizer == nil {
return NullAuthorizer{}
}
return c.Authorizer
}
// WithInspection is a convenience method that passes the request to the supplied RequestInspector,
// if present, or returns the WithNothing PrepareDecorator otherwise.
func (c Client) WithInspection() PrepareDecorator {
if c.RequestInspector == nil {
return WithNothing()
}
return c.RequestInspector
}
// ByInspecting is a convenience method that passes the response to the supplied ResponseInspector,
// if present, or returns the ByIgnoring RespondDecorator otherwise.
func (c Client) ByInspecting() RespondDecorator {
if c.ResponseInspector == nil {
return ByIgnoring()
}
return c.ResponseInspector
}

View File

@@ -1,402 +0,0 @@
package autorest
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"bytes"
"fmt"
"io/ioutil"
"log"
"math/rand"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"time"
"github.com/Azure/go-autorest/autorest/mocks"
)
func TestLoggingInspectorWithInspection(t *testing.T) {
b := bytes.Buffer{}
c := Client{}
li := LoggingInspector{Logger: log.New(&b, "", 0)}
c.RequestInspector = li.WithInspection()
Prepare(mocks.NewRequestWithContent("Content"),
c.WithInspection())
if len(b.String()) <= 0 {
t.Fatal("autorest: LoggingInspector#WithInspection did not record Request to the log")
}
}
func TestLoggingInspectorWithInspectionEmitsErrors(t *testing.T) {
b := bytes.Buffer{}
c := Client{}
r := mocks.NewRequestWithContent("Content")
li := LoggingInspector{Logger: log.New(&b, "", 0)}
c.RequestInspector = li.WithInspection()
if _, err := Prepare(r,
c.WithInspection()); err != nil {
t.Error(err)
}
if len(b.String()) <= 0 {
t.Fatal("autorest: LoggingInspector#WithInspection did not record Request to the log")
}
}
func TestLoggingInspectorWithInspectionRestoresBody(t *testing.T) {
b := bytes.Buffer{}
c := Client{}
r := mocks.NewRequestWithContent("Content")
li := LoggingInspector{Logger: log.New(&b, "", 0)}
c.RequestInspector = li.WithInspection()
Prepare(r,
c.WithInspection())
s, _ := ioutil.ReadAll(r.Body)
if len(s) <= 0 {
t.Fatal("autorest: LoggingInspector#WithInspection did not restore the Request body")
}
}
func TestLoggingInspectorByInspecting(t *testing.T) {
b := bytes.Buffer{}
c := Client{}
li := LoggingInspector{Logger: log.New(&b, "", 0)}
c.ResponseInspector = li.ByInspecting()
Respond(mocks.NewResponseWithContent("Content"),
c.ByInspecting())
if len(b.String()) <= 0 {
t.Fatal("autorest: LoggingInspector#ByInspection did not record Response to the log")
}
}
func TestLoggingInspectorByInspectingEmitsErrors(t *testing.T) {
b := bytes.Buffer{}
c := Client{}
r := mocks.NewResponseWithContent("Content")
li := LoggingInspector{Logger: log.New(&b, "", 0)}
c.ResponseInspector = li.ByInspecting()
if err := Respond(r,
c.ByInspecting()); err != nil {
t.Fatal(err)
}
if len(b.String()) <= 0 {
t.Fatal("autorest: LoggingInspector#ByInspection did not record Response to the log")
}
}
func TestLoggingInspectorByInspectingRestoresBody(t *testing.T) {
b := bytes.Buffer{}
c := Client{}
r := mocks.NewResponseWithContent("Content")
li := LoggingInspector{Logger: log.New(&b, "", 0)}
c.ResponseInspector = li.ByInspecting()
Respond(r,
c.ByInspecting())
s, _ := ioutil.ReadAll(r.Body)
if len(s) <= 0 {
t.Fatal("autorest: LoggingInspector#ByInspecting did not restore the Response body")
}
}
func TestNewClientWithUserAgent(t *testing.T) {
ua := "UserAgent"
c := NewClientWithUserAgent(ua)
completeUA := fmt.Sprintf("%s %s", defaultUserAgent, ua)
if c.UserAgent != completeUA {
t.Fatalf("autorest: NewClientWithUserAgent failed to set the UserAgent -- expected %s, received %s",
completeUA, c.UserAgent)
}
}
func TestAddToUserAgent(t *testing.T) {
ua := "UserAgent"
c := NewClientWithUserAgent(ua)
ext := "extension"
err := c.AddToUserAgent(ext)
if err != nil {
t.Fatalf("autorest: AddToUserAgent returned error -- expected nil, received %s", err)
}
completeUA := fmt.Sprintf("%s %s %s", defaultUserAgent, ua, ext)
if c.UserAgent != completeUA {
t.Fatalf("autorest: AddToUserAgent failed to add an extension to the UserAgent -- expected %s, received %s",
completeUA, c.UserAgent)
}
err = c.AddToUserAgent("")
if err == nil {
t.Fatalf("autorest: AddToUserAgent didn't return error -- expected %s, received nil",
fmt.Errorf("Extension was empty, User Agent stayed as %s", c.UserAgent))
}
if c.UserAgent != completeUA {
t.Fatalf("autorest: AddToUserAgent failed to not add an empty extension to the UserAgent -- expected %s, received %s",
completeUA, c.UserAgent)
}
}
func TestClientSenderReturnsHttpClientByDefault(t *testing.T) {
c := Client{}
if fmt.Sprintf("%T", c.sender()) != "*http.Client" {
t.Fatal("autorest: Client#sender failed to return http.Client by default")
}
}
func TestClientSenderReturnsSetSender(t *testing.T) {
c := Client{}
s := mocks.NewSender()
c.Sender = s
if c.sender() != s {
t.Fatal("autorest: Client#sender failed to return set Sender")
}
}
func TestClientDoInvokesSender(t *testing.T) {
c := Client{}
s := mocks.NewSender()
c.Sender = s
c.Do(&http.Request{})
if s.Attempts() != 1 {
t.Fatal("autorest: Client#Do failed to invoke the Sender")
}
}
func TestClientDoSetsUserAgent(t *testing.T) {
ua := "UserAgent"
c := Client{UserAgent: ua}
r := mocks.NewRequest()
s := mocks.NewSender()
c.Sender = s
c.Do(r)
if r.UserAgent() != ua {
t.Fatalf("autorest: Client#Do failed to correctly set User-Agent header: %s=%s",
http.CanonicalHeaderKey(headerUserAgent), r.UserAgent())
}
}
func TestClientDoSetsAuthorization(t *testing.T) {
r := mocks.NewRequest()
s := mocks.NewSender()
c := Client{Authorizer: mockAuthorizer{}, Sender: s}
c.Do(r)
if len(r.Header.Get(http.CanonicalHeaderKey(headerAuthorization))) <= 0 {
t.Fatalf("autorest: Client#Send failed to set Authorization header -- %s=%s",
http.CanonicalHeaderKey(headerAuthorization),
r.Header.Get(http.CanonicalHeaderKey(headerAuthorization)))
}
}
func TestClientDoInvokesRequestInspector(t *testing.T) {
r := mocks.NewRequest()
s := mocks.NewSender()
i := &mockInspector{}
c := Client{RequestInspector: i.WithInspection(), Sender: s}
c.Do(r)
if !i.wasInvoked {
t.Fatal("autorest: Client#Send failed to invoke the RequestInspector")
}
}
func TestClientDoInvokesResponseInspector(t *testing.T) {
r := mocks.NewRequest()
s := mocks.NewSender()
i := &mockInspector{}
c := Client{ResponseInspector: i.ByInspecting(), Sender: s}
c.Do(r)
if !i.wasInvoked {
t.Fatal("autorest: Client#Send failed to invoke the ResponseInspector")
}
}
func TestClientDoReturnsErrorIfPrepareFails(t *testing.T) {
c := Client{}
s := mocks.NewSender()
c.Authorizer = mockFailingAuthorizer{}
c.Sender = s
_, err := c.Do(&http.Request{})
if err == nil {
t.Fatalf("autorest: Client#Do failed to return an error when Prepare failed")
}
}
func TestClientDoDoesNotSendIfPrepareFails(t *testing.T) {
c := Client{}
s := mocks.NewSender()
c.Authorizer = mockFailingAuthorizer{}
c.Sender = s
c.Do(&http.Request{})
if s.Attempts() > 0 {
t.Fatal("autorest: Client#Do failed to invoke the Sender")
}
}
func TestClientAuthorizerReturnsNullAuthorizerByDefault(t *testing.T) {
c := Client{}
if fmt.Sprintf("%T", c.authorizer()) != "autorest.NullAuthorizer" {
t.Fatal("autorest: Client#authorizer failed to return the NullAuthorizer by default")
}
}
func TestClientAuthorizerReturnsSetAuthorizer(t *testing.T) {
c := Client{}
c.Authorizer = mockAuthorizer{}
if fmt.Sprintf("%T", c.authorizer()) != "autorest.mockAuthorizer" {
t.Fatal("autorest: Client#authorizer failed to return the set Authorizer")
}
}
func TestClientWithAuthorizer(t *testing.T) {
c := Client{}
c.Authorizer = mockAuthorizer{}
req, _ := Prepare(&http.Request{},
c.WithAuthorization())
if req.Header.Get(headerAuthorization) == "" {
t.Fatal("autorest: Client#WithAuthorizer failed to return the WithAuthorizer from the active Authorizer")
}
}
func TestClientWithInspection(t *testing.T) {
c := Client{}
r := &mockInspector{}
c.RequestInspector = r.WithInspection()
Prepare(&http.Request{},
c.WithInspection())
if !r.wasInvoked {
t.Fatal("autorest: Client#WithInspection failed to invoke RequestInspector")
}
}
func TestClientWithInspectionSetsDefault(t *testing.T) {
c := Client{}
r1 := &http.Request{}
r2, _ := Prepare(r1,
c.WithInspection())
if !reflect.DeepEqual(r1, r2) {
t.Fatal("autorest: Client#WithInspection failed to provide a default RequestInspector")
}
}
func TestClientByInspecting(t *testing.T) {
c := Client{}
r := &mockInspector{}
c.ResponseInspector = r.ByInspecting()
Respond(&http.Response{},
c.ByInspecting())
if !r.wasInvoked {
t.Fatal("autorest: Client#ByInspecting failed to invoke ResponseInspector")
}
}
func TestClientByInspectingSetsDefault(t *testing.T) {
c := Client{}
r := &http.Response{}
Respond(r,
c.ByInspecting())
if !reflect.DeepEqual(r, &http.Response{}) {
t.Fatal("autorest: Client#ByInspecting failed to provide a default ResponseInspector")
}
}
func TestCookies(t *testing.T) {
second := "second"
expected := http.Cookie{
Name: "tastes",
Value: "delicious",
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.SetCookie(w, &expected)
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("autorest: ioutil.ReadAll failed reading request body: %s", err)
}
if string(b) == second {
cookie, err := r.Cookie(expected.Name)
if err != nil {
t.Fatalf("autorest: r.Cookie could not get request cookie: %s", err)
}
if cookie == nil {
t.Fatalf("autorest: got nil cookie, expecting %v", expected)
}
if cookie.Value != expected.Value {
t.Fatalf("autorest: got cookie value '%s', expecting '%s'", cookie.Value, expected.Name)
}
}
}))
defer server.Close()
client := NewClientWithUserAgent("")
_, err := SendWithSender(client, mocks.NewRequestForURL(server.URL))
if err != nil {
t.Fatalf("autorest: first request failed: %s", err)
}
r2, err := http.NewRequest(http.MethodGet, server.URL, mocks.NewBody(second))
if err != nil {
t.Fatalf("autorest: failed creating second request: %s", err)
}
_, err = SendWithSender(client, r2)
if err != nil {
t.Fatalf("autorest: second request failed: %s", err)
}
}
func randomString(n int) string {
const chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
r := rand.New(rand.NewSource(time.Now().UTC().UnixNano()))
s := make([]byte, n)
for i := range s {
s[i] = chars[r.Intn(len(chars))]
}
return string(s)
}

View File

@@ -1,237 +0,0 @@
package date
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"encoding/json"
"fmt"
"reflect"
"testing"
"time"
)
func ExampleParseDate() {
d, err := ParseDate("2001-02-03")
if err != nil {
fmt.Println(err)
}
fmt.Println(d)
// Output: 2001-02-03
}
func ExampleDate() {
d, err := ParseDate("2001-02-03")
if err != nil {
fmt.Println(err)
}
t, err := time.Parse(time.RFC3339, "2001-02-04T00:00:00Z")
if err != nil {
fmt.Println(err)
}
// Date acts as time.Time when the receiver
if d.Before(t) {
fmt.Printf("Before ")
} else {
fmt.Printf("After ")
}
// Convert Date when needing a time.Time
if t.After(d.ToTime()) {
fmt.Printf("After")
} else {
fmt.Printf("Before")
}
// Output: Before After
}
func ExampleDate_MarshalBinary() {
d, err := ParseDate("2001-02-03")
if err != nil {
fmt.Println(err)
}
t, err := d.MarshalBinary()
if err != nil {
fmt.Println(err)
}
fmt.Println(string(t))
// Output: 2001-02-03
}
func ExampleDate_UnmarshalBinary() {
d := Date{}
t := "2001-02-03"
if err := d.UnmarshalBinary([]byte(t)); err != nil {
fmt.Println(err)
}
fmt.Println(d)
// Output: 2001-02-03
}
func ExampleDate_MarshalJSON() {
d, err := ParseDate("2001-02-03")
if err != nil {
fmt.Println(err)
}
j, err := json.Marshal(d)
if err != nil {
fmt.Println(err)
}
fmt.Println(string(j))
// Output: "2001-02-03"
}
func ExampleDate_UnmarshalJSON() {
var d struct {
Date Date `json:"date"`
}
j := `{"date" : "2001-02-03"}`
if err := json.Unmarshal([]byte(j), &d); err != nil {
fmt.Println(err)
}
fmt.Println(d.Date)
// Output: 2001-02-03
}
func ExampleDate_MarshalText() {
d, err := ParseDate("2001-02-03")
if err != nil {
fmt.Println(err)
}
t, err := d.MarshalText()
if err != nil {
fmt.Println(err)
}
fmt.Println(string(t))
// Output: 2001-02-03
}
func ExampleDate_UnmarshalText() {
d := Date{}
t := "2001-02-03"
if err := d.UnmarshalText([]byte(t)); err != nil {
fmt.Println(err)
}
fmt.Println(d)
// Output: 2001-02-03
}
func TestDateString(t *testing.T) {
d, err := ParseDate("2001-02-03")
if err != nil {
t.Fatalf("date: String failed (%v)", err)
}
if d.String() != "2001-02-03" {
t.Fatalf("date: String failed (%v)", d.String())
}
}
func TestDateBinaryRoundTrip(t *testing.T) {
d1, err := ParseDate("2001-02-03")
if err != nil {
t.Fatalf("date: ParseDate failed (%v)", err)
}
t1, err := d1.MarshalBinary()
if err != nil {
t.Fatalf("date: MarshalBinary failed (%v)", err)
}
d2 := Date{}
if err = d2.UnmarshalBinary(t1); err != nil {
t.Fatalf("date: UnmarshalBinary failed (%v)", err)
}
if !reflect.DeepEqual(d1, d2) {
t.Fatalf("date: Round-trip Binary failed (%v, %v)", d1, d2)
}
}
func TestDateJSONRoundTrip(t *testing.T) {
type s struct {
Date Date `json:"date"`
}
var err error
d1 := s{}
d1.Date, err = ParseDate("2001-02-03")
if err != nil {
t.Fatalf("date: ParseDate failed (%v)", err)
}
j, err := json.Marshal(d1)
if err != nil {
t.Fatalf("date: MarshalJSON failed (%v)", err)
}
d2 := s{}
if err = json.Unmarshal(j, &d2); err != nil {
t.Fatalf("date: UnmarshalJSON failed (%v)", err)
}
if !reflect.DeepEqual(d1, d2) {
t.Fatalf("date: Round-trip JSON failed (%v, %v)", d1, d2)
}
}
func TestDateTextRoundTrip(t *testing.T) {
d1, err := ParseDate("2001-02-03")
if err != nil {
t.Fatalf("date: ParseDate failed (%v)", err)
}
t1, err := d1.MarshalText()
if err != nil {
t.Fatalf("date: MarshalText failed (%v)", err)
}
d2 := Date{}
if err = d2.UnmarshalText(t1); err != nil {
t.Fatalf("date: UnmarshalText failed (%v)", err)
}
if !reflect.DeepEqual(d1, d2) {
t.Fatalf("date: Round-trip Text failed (%v, %v)", d1, d2)
}
}
func TestDateToTime(t *testing.T) {
var d Date
d, err := ParseDate("2001-02-03")
if err != nil {
t.Fatalf("date: ParseDate failed (%v)", err)
}
var _ time.Time = d.ToTime()
}
func TestDateUnmarshalJSONReturnsError(t *testing.T) {
var d struct {
Date Date `json:"date"`
}
j := `{"date" : "February 3, 2001"}`
if err := json.Unmarshal([]byte(j), &d); err == nil {
t.Fatal("date: Date failed to return error for malformed JSON date")
}
}
func TestDateUnmarshalTextReturnsError(t *testing.T) {
d := Date{}
txt := "February 3, 2001"
if err := d.UnmarshalText([]byte(txt)); err == nil {
t.Fatal("date: Date failed to return error for malformed Text date")
}
}

View File

@@ -1,277 +0,0 @@
package date
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"encoding/json"
"fmt"
"reflect"
"testing"
"time"
)
func ExampleParseTime() {
d, _ := ParseTime(rfc3339, "2001-02-03T04:05:06Z")
fmt.Println(d)
// Output: 2001-02-03 04:05:06 +0000 UTC
}
func ExampleTime_MarshalBinary() {
ti, err := ParseTime(rfc3339, "2001-02-03T04:05:06Z")
if err != nil {
fmt.Println(err)
}
d := Time{ti}
t, err := d.MarshalBinary()
if err != nil {
fmt.Println(err)
}
fmt.Println(string(t))
// Output: 2001-02-03T04:05:06Z
}
func ExampleTime_UnmarshalBinary() {
d := Time{}
t := "2001-02-03T04:05:06Z"
if err := d.UnmarshalBinary([]byte(t)); err != nil {
fmt.Println(err)
}
fmt.Println(d)
// Output: 2001-02-03T04:05:06Z
}
func ExampleTime_MarshalJSON() {
d, err := ParseTime(rfc3339, "2001-02-03T04:05:06Z")
if err != nil {
fmt.Println(err)
}
j, err := json.Marshal(d)
if err != nil {
fmt.Println(err)
}
fmt.Println(string(j))
// Output: "2001-02-03T04:05:06Z"
}
func ExampleTime_UnmarshalJSON() {
var d struct {
Time Time `json:"datetime"`
}
j := `{"datetime" : "2001-02-03T04:05:06Z"}`
if err := json.Unmarshal([]byte(j), &d); err != nil {
fmt.Println(err)
}
fmt.Println(d.Time)
// Output: 2001-02-03T04:05:06Z
}
func ExampleTime_MarshalText() {
d, err := ParseTime(rfc3339, "2001-02-03T04:05:06Z")
if err != nil {
fmt.Println(err)
}
t, err := d.MarshalText()
if err != nil {
fmt.Println(err)
}
fmt.Println(string(t))
// Output: 2001-02-03T04:05:06Z
}
func ExampleTime_UnmarshalText() {
d := Time{}
t := "2001-02-03T04:05:06Z"
if err := d.UnmarshalText([]byte(t)); err != nil {
fmt.Println(err)
}
fmt.Println(d)
// Output: 2001-02-03T04:05:06Z
}
func TestUnmarshalTextforInvalidDate(t *testing.T) {
d := Time{}
dt := "2001-02-03T04:05:06AAA"
if err := d.UnmarshalText([]byte(dt)); err == nil {
t.Fatalf("date: Time#Unmarshal was expecting error for invalid date")
}
}
func TestUnmarshalJSONforInvalidDate(t *testing.T) {
d := Time{}
dt := `"2001-02-03T04:05:06AAA"`
if err := d.UnmarshalJSON([]byte(dt)); err == nil {
t.Fatalf("date: Time#Unmarshal was expecting error for invalid date")
}
}
func TestTimeString(t *testing.T) {
ti, err := ParseTime(rfc3339, "2001-02-03T04:05:06Z")
if err != nil {
fmt.Println(err)
}
d := Time{ti}
if d.String() != "2001-02-03T04:05:06Z" {
t.Fatalf("date: Time#String failed (%v)", d.String())
}
}
func TestTimeStringReturnsEmptyStringForError(t *testing.T) {
d := Time{Time: time.Date(20000, 01, 01, 01, 01, 01, 01, time.UTC)}
if d.String() != "" {
t.Fatalf("date: Time#String failed empty string for an error")
}
}
func TestTimeBinaryRoundTrip(t *testing.T) {
ti, err := ParseTime(rfc3339, "2001-02-03T04:05:06Z")
if err != nil {
t.Fatalf("date: Time#ParseTime failed (%v)", err)
}
d1 := Time{ti}
t1, err := d1.MarshalBinary()
if err != nil {
t.Fatalf("date: Time#MarshalBinary failed (%v)", err)
}
d2 := Time{}
if err = d2.UnmarshalBinary(t1); err != nil {
t.Fatalf("date: Time#UnmarshalBinary failed (%v)", err)
}
if !reflect.DeepEqual(d1, d2) {
t.Fatalf("date:Round-trip Binary failed (%v, %v)", d1, d2)
}
}
func TestTimeJSONRoundTrip(t *testing.T) {
type s struct {
Time Time `json:"datetime"`
}
ti, err := ParseTime(rfc3339, "2001-02-03T04:05:06Z")
if err != nil {
t.Fatalf("date: Time#ParseTime failed (%v)", err)
}
d1 := s{Time: Time{ti}}
j, err := json.Marshal(d1)
if err != nil {
t.Fatalf("date: Time#MarshalJSON failed (%v)", err)
}
d2 := s{}
if err = json.Unmarshal(j, &d2); err != nil {
t.Fatalf("date: Time#UnmarshalJSON failed (%v)", err)
}
if !reflect.DeepEqual(d1, d2) {
t.Fatalf("date: Round-trip JSON failed (%v, %v)", d1, d2)
}
}
func TestTimeTextRoundTrip(t *testing.T) {
ti, err := ParseTime(rfc3339, "2001-02-03T04:05:06Z")
if err != nil {
t.Fatalf("date: Time#ParseTime failed (%v)", err)
}
d1 := Time{Time: ti}
t1, err := d1.MarshalText()
if err != nil {
t.Fatalf("date: Time#MarshalText failed (%v)", err)
}
d2 := Time{}
if err = d2.UnmarshalText(t1); err != nil {
t.Fatalf("date: Time#UnmarshalText failed (%v)", err)
}
if !reflect.DeepEqual(d1, d2) {
t.Fatalf("date: Round-trip Text failed (%v, %v)", d1, d2)
}
}
func TestTimeToTime(t *testing.T) {
ti, err := ParseTime(rfc3339, "2001-02-03T04:05:06Z")
d := Time{ti}
if err != nil {
t.Fatalf("date: Time#ParseTime failed (%v)", err)
}
var _ time.Time = d.ToTime()
}
func TestUnmarshalJSONNoOffset(t *testing.T) {
var d struct {
Time Time `json:"datetime"`
}
j := `{"datetime" : "2001-02-03T04:05:06.789"}`
if err := json.Unmarshal([]byte(j), &d); err != nil {
t.Fatalf("date: Time#Unmarshal failed (%v)", err)
}
}
func TestUnmarshalJSONPosOffset(t *testing.T) {
var d struct {
Time Time `json:"datetime"`
}
j := `{"datetime" : "1980-01-02T00:11:35.01+01:00"}`
if err := json.Unmarshal([]byte(j), &d); err != nil {
t.Fatalf("date: Time#Unmarshal failed (%v)", err)
}
}
func TestUnmarshalJSONNegOffset(t *testing.T) {
var d struct {
Time Time `json:"datetime"`
}
j := `{"datetime" : "1492-10-12T10:15:01.789-08:00"}`
if err := json.Unmarshal([]byte(j), &d); err != nil {
t.Fatalf("date: Time#Unmarshal failed (%v)", err)
}
}
func TestUnmarshalTextNoOffset(t *testing.T) {
d := Time{}
t1 := "2001-02-03T04:05:06"
if err := d.UnmarshalText([]byte(t1)); err != nil {
t.Fatalf("date: Time#UnmarshalText failed (%v)", err)
}
}
func TestUnmarshalTextPosOffset(t *testing.T) {
d := Time{}
t1 := "2001-02-03T04:05:06+00:30"
if err := d.UnmarshalText([]byte(t1)); err != nil {
t.Fatalf("date: Time#UnmarshalText failed (%v)", err)
}
}
func TestUnmarshalTextNegOffset(t *testing.T) {
d := Time{}
t1 := "2001-02-03T04:05:06-11:00"
if err := d.UnmarshalText([]byte(t1)); err != nil {
t.Fatalf("date: Time#UnmarshalText failed (%v)", err)
}
}

View File

@@ -1,226 +0,0 @@
package date
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"encoding/json"
"fmt"
"reflect"
"testing"
"time"
)
func ExampleTimeRFC1123() {
d, err := ParseTime(rfc1123, "Mon, 02 Jan 2006 15:04:05 MST")
if err != nil {
fmt.Println(err)
}
fmt.Println(d)
// Output: 2006-01-02 15:04:05 +0000 MST
}
func ExampleTimeRFC1123_MarshalBinary() {
ti, err := ParseTime(rfc1123, "Mon, 02 Jan 2006 15:04:05 MST")
if err != nil {
fmt.Println(err)
}
d := TimeRFC1123{ti}
b, err := d.MarshalBinary()
if err != nil {
fmt.Println(err)
}
fmt.Println(string(b))
// Output: Mon, 02 Jan 2006 15:04:05 MST
}
func ExampleTimeRFC1123_UnmarshalBinary() {
d := TimeRFC1123{}
t := "Mon, 02 Jan 2006 15:04:05 MST"
if err := d.UnmarshalBinary([]byte(t)); err != nil {
fmt.Println(err)
}
fmt.Println(d)
// Output: Mon, 02 Jan 2006 15:04:05 MST
}
func ExampleTimeRFC1123_MarshalJSON() {
ti, err := ParseTime(rfc1123, "Mon, 02 Jan 2006 15:04:05 MST")
if err != nil {
fmt.Println(err)
}
d := TimeRFC1123{ti}
j, err := json.Marshal(d)
if err != nil {
fmt.Println(err)
}
fmt.Println(string(j))
// Output: "Mon, 02 Jan 2006 15:04:05 MST"
}
func TestTimeRFC1123MarshalJSONInvalid(t *testing.T) {
ti := time.Date(20000, 01, 01, 00, 00, 00, 00, time.UTC)
d := TimeRFC1123{ti}
if _, err := json.Marshal(d); err == nil {
t.Fatalf("date: TimeRFC1123#Marshal failed for invalid date")
}
}
func ExampleTimeRFC1123_UnmarshalJSON() {
var d struct {
Time TimeRFC1123 `json:"datetime"`
}
j := `{"datetime" : "Mon, 02 Jan 2006 15:04:05 MST"}`
if err := json.Unmarshal([]byte(j), &d); err != nil {
fmt.Println(err)
}
fmt.Println(d.Time)
// Output: Mon, 02 Jan 2006 15:04:05 MST
}
func ExampleTimeRFC1123_MarshalText() {
ti, err := ParseTime(rfc3339, "2001-02-03T04:05:06Z")
if err != nil {
fmt.Println(err)
}
d := TimeRFC1123{ti}
t, err := d.MarshalText()
if err != nil {
fmt.Println(err)
}
fmt.Println(string(t))
// Output: Sat, 03 Feb 2001 04:05:06 UTC
}
func ExampleTimeRFC1123_UnmarshalText() {
d := TimeRFC1123{}
t := "Sat, 03 Feb 2001 04:05:06 UTC"
if err := d.UnmarshalText([]byte(t)); err != nil {
fmt.Println(err)
}
fmt.Println(d)
// Output: Sat, 03 Feb 2001 04:05:06 UTC
}
func TestUnmarshalJSONforInvalidDateRfc1123(t *testing.T) {
dt := `"Mon, 02 Jan 2000000 15:05 MST"`
d := TimeRFC1123{}
if err := d.UnmarshalJSON([]byte(dt)); err == nil {
t.Fatalf("date: TimeRFC1123#Unmarshal failed for invalid date")
}
}
func TestUnmarshalTextforInvalidDateRfc1123(t *testing.T) {
dt := "Mon, 02 Jan 2000000 15:05 MST"
d := TimeRFC1123{}
if err := d.UnmarshalText([]byte(dt)); err == nil {
t.Fatalf("date: TimeRFC1123#Unmarshal failed for invalid date")
}
}
func TestTimeStringRfc1123(t *testing.T) {
ti, err := ParseTime(rfc1123, "Mon, 02 Jan 2006 15:04:05 MST")
if err != nil {
fmt.Println(err)
}
d := TimeRFC1123{ti}
if d.String() != "Mon, 02 Jan 2006 15:04:05 MST" {
t.Fatalf("date: TimeRFC1123#String failed (%v)", d.String())
}
}
func TestTimeStringReturnsEmptyStringForErrorRfc1123(t *testing.T) {
d := TimeRFC1123{Time: time.Date(20000, 01, 01, 01, 01, 01, 01, time.UTC)}
if d.String() != "" {
t.Fatalf("date: TimeRFC1123#String failed empty string for an error")
}
}
func TestTimeBinaryRoundTripRfc1123(t *testing.T) {
ti, err := ParseTime(rfc3339, "2001-02-03T04:05:06Z")
if err != nil {
t.Fatalf("date: TimeRFC1123#ParseTime failed (%v)", err)
}
d1 := TimeRFC1123{ti}
t1, err := d1.MarshalBinary()
if err != nil {
t.Fatalf("date: TimeRFC1123#MarshalBinary failed (%v)", err)
}
d2 := TimeRFC1123{}
if err = d2.UnmarshalBinary(t1); err != nil {
t.Fatalf("date: TimeRFC1123#UnmarshalBinary failed (%v)", err)
}
if !reflect.DeepEqual(d1, d2) {
t.Fatalf("date: Round-trip Binary failed (%v, %v)", d1, d2)
}
}
func TestTimeJSONRoundTripRfc1123(t *testing.T) {
type s struct {
Time TimeRFC1123 `json:"datetime"`
}
var err error
ti, err := ParseTime(rfc1123, "Mon, 02 Jan 2006 15:04:05 MST")
if err != nil {
t.Fatalf("date: TimeRFC1123#ParseTime failed (%v)", err)
}
d1 := s{Time: TimeRFC1123{ti}}
j, err := json.Marshal(d1)
if err != nil {
t.Fatalf("date: TimeRFC1123#MarshalJSON failed (%v)", err)
}
d2 := s{}
if err = json.Unmarshal(j, &d2); err != nil {
t.Fatalf("date: TimeRFC1123#UnmarshalJSON failed (%v)", err)
}
if !reflect.DeepEqual(d1, d2) {
t.Fatalf("date: Round-trip JSON failed (%v, %v)", d1, d2)
}
}
func TestTimeTextRoundTripRfc1123(t *testing.T) {
ti, err := ParseTime(rfc1123, "Mon, 02 Jan 2006 15:04:05 MST")
if err != nil {
t.Fatalf("date: TimeRFC1123#ParseTime failed (%v)", err)
}
d1 := TimeRFC1123{Time: ti}
t1, err := d1.MarshalText()
if err != nil {
t.Fatalf("date: TimeRFC1123#MarshalText failed (%v)", err)
}
d2 := TimeRFC1123{}
if err = d2.UnmarshalText(t1); err != nil {
t.Fatalf("date: TimeRFC1123#UnmarshalText failed (%v)", err)
}
if !reflect.DeepEqual(d1, d2) {
t.Fatalf("date: Round-trip Text failed (%v, %v)", d1, d2)
}
}
func TestTimeToTimeRFC1123(t *testing.T) {
ti, err := ParseTime(rfc1123, "Mon, 02 Jan 2006 15:04:05 MST")
d := TimeRFC1123{ti}
if err != nil {
t.Fatalf("date: TimeRFC1123#ParseTime failed (%v)", err)
}
var _ time.Time = d.ToTime()
}

View File

@@ -1,283 +0,0 @@
// +build go1.7
package date
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"bytes"
"encoding/binary"
"encoding/json"
"fmt"
"math"
"testing"
"time"
)
func ExampleUnixTime_MarshalJSON() {
epoch := UnixTime(UnixEpoch())
text, _ := json.Marshal(epoch)
fmt.Print(string(text))
// Output: 0
}
func ExampleUnixTime_UnmarshalJSON() {
var myTime UnixTime
json.Unmarshal([]byte("1.3e2"), &myTime)
fmt.Printf("%v", time.Time(myTime))
// Output: 1970-01-01 00:02:10 +0000 UTC
}
func TestUnixTime_MarshalJSON(t *testing.T) {
testCases := []time.Time{
UnixEpoch().Add(-1 * time.Second), // One second befote the Unix Epoch
time.Date(2017, time.April, 14, 20, 27, 47, 0, time.UTC), // The time this test was written
UnixEpoch(),
time.Date(1800, 01, 01, 0, 0, 0, 0, time.UTC),
time.Date(2200, 12, 29, 00, 01, 37, 82, time.UTC),
}
for _, tc := range testCases {
t.Run(tc.String(), func(subT *testing.T) {
var actual, expected float64
var marshaled []byte
target := UnixTime(tc)
expected = float64(target.Duration().Nanoseconds()) / 1e9
if temp, err := json.Marshal(target); err == nil {
marshaled = temp
} else {
subT.Error(err)
return
}
dec := json.NewDecoder(bytes.NewReader(marshaled))
if err := dec.Decode(&actual); err != nil {
subT.Error(err)
return
}
diff := math.Abs(actual - expected)
subT.Logf("\ngot :\t%g\nwant:\t%g\ndiff:\t%g", actual, expected, diff)
if diff > 1e-9 { //Must be within 1 nanosecond of one another
subT.Fail()
}
})
}
}
func TestUnixTime_UnmarshalJSON(t *testing.T) {
testCases := []struct {
text string
expected time.Time
}{
{"1", UnixEpoch().Add(time.Second)},
{"0", UnixEpoch()},
{"1492203742", time.Date(2017, time.April, 14, 21, 02, 22, 0, time.UTC)}, // The time this test was written
{"-1", time.Date(1969, time.December, 31, 23, 59, 59, 0, time.UTC)},
{"1.5", UnixEpoch().Add(1500 * time.Millisecond)},
{"0e1", UnixEpoch()}, // See http://json.org for 'number' format definition.
{"1.3e+2", UnixEpoch().Add(130 * time.Second)},
{"1.6E-10", UnixEpoch()}, // This is so small, it should get truncated into the UnixEpoch
{"2E-6", UnixEpoch().Add(2 * time.Microsecond)},
{"1.289345e9", UnixEpoch().Add(1289345000 * time.Second)},
{"1e-9", UnixEpoch().Add(time.Nanosecond)},
}
for _, tc := range testCases {
t.Run(tc.text, func(subT *testing.T) {
var rehydrated UnixTime
if err := json.Unmarshal([]byte(tc.text), &rehydrated); err != nil {
subT.Error(err)
return
}
if time.Time(rehydrated) != tc.expected {
subT.Logf("\ngot: \t%v\nwant:\t%v\ndiff:\t%v", time.Time(rehydrated), tc.expected, time.Time(rehydrated).Sub(tc.expected))
subT.Fail()
}
})
}
}
func TestUnixTime_JSONRoundTrip(t *testing.T) {
testCases := []time.Time{
UnixEpoch(),
time.Date(2005, time.November, 5, 0, 0, 0, 0, time.UTC), // The day V for Vendetta (film) was released.
UnixEpoch().Add(-6 * time.Second),
UnixEpoch().Add(800 * time.Hour),
UnixEpoch().Add(time.Nanosecond),
time.Date(2015, time.September, 05, 4, 30, 12, 9992, time.UTC),
}
for _, tc := range testCases {
t.Run(tc.String(), func(subT *testing.T) {
subject := UnixTime(tc)
var marshaled []byte
if temp, err := json.Marshal(subject); err == nil {
marshaled = temp
} else {
subT.Error(err)
return
}
var unmarshaled UnixTime
if err := json.Unmarshal(marshaled, &unmarshaled); err != nil {
subT.Error(err)
}
actual := time.Time(unmarshaled)
diff := actual.Sub(tc)
subT.Logf("\ngot :\t%s\nwant:\t%s\ndiff:\t%s", actual.String(), tc.String(), diff.String())
if diff > time.Duration(100) { // We lose some precision be working in floats. We shouldn't lose more than 100 nanoseconds.
subT.Fail()
}
})
}
}
func TestUnixTime_MarshalBinary(t *testing.T) {
testCases := []struct {
expected int64
subject time.Time
}{
{0, UnixEpoch()},
{-15 * int64(time.Second), UnixEpoch().Add(-15 * time.Second)},
{54, UnixEpoch().Add(54 * time.Nanosecond)},
}
for _, tc := range testCases {
t.Run("", func(subT *testing.T) {
var marshaled []byte
if temp, err := UnixTime(tc.subject).MarshalBinary(); err == nil {
marshaled = temp
} else {
subT.Error(err)
return
}
var unmarshaled int64
if err := binary.Read(bytes.NewReader(marshaled), binary.LittleEndian, &unmarshaled); err != nil {
subT.Error(err)
return
}
if unmarshaled != tc.expected {
subT.Logf("\ngot: \t%d\nwant:\t%d", unmarshaled, tc.expected)
subT.Fail()
}
})
}
}
func TestUnixTime_BinaryRoundTrip(t *testing.T) {
testCases := []time.Time{
UnixEpoch(),
UnixEpoch().Add(800 * time.Minute),
UnixEpoch().Add(7 * time.Hour),
UnixEpoch().Add(-1 * time.Nanosecond),
}
for _, tc := range testCases {
t.Run(tc.String(), func(subT *testing.T) {
original := UnixTime(tc)
var marshaled []byte
if temp, err := original.MarshalBinary(); err == nil {
marshaled = temp
} else {
subT.Error(err)
return
}
var traveled UnixTime
if err := traveled.UnmarshalBinary(marshaled); err != nil {
subT.Error(err)
return
}
if traveled != original {
subT.Logf("\ngot: \t%s\nwant:\t%s", time.Time(original).String(), time.Time(traveled).String())
subT.Fail()
}
})
}
}
func TestUnixTime_MarshalText(t *testing.T) {
testCases := []time.Time{
UnixEpoch(),
UnixEpoch().Add(45 * time.Second),
UnixEpoch().Add(time.Nanosecond),
UnixEpoch().Add(-100000 * time.Second),
}
for _, tc := range testCases {
expected, _ := tc.MarshalText()
t.Run("", func(subT *testing.T) {
var marshaled []byte
if temp, err := UnixTime(tc).MarshalText(); err == nil {
marshaled = temp
} else {
subT.Error(err)
return
}
if string(marshaled) != string(expected) {
subT.Logf("\ngot: \t%s\nwant:\t%s", string(marshaled), string(expected))
subT.Fail()
}
})
}
}
func TestUnixTime_TextRoundTrip(t *testing.T) {
testCases := []time.Time{
UnixEpoch(),
UnixEpoch().Add(-1 * time.Nanosecond),
UnixEpoch().Add(1 * time.Nanosecond),
time.Date(2017, time.April, 17, 21, 00, 00, 00, time.UTC),
}
for _, tc := range testCases {
t.Run(tc.String(), func(subT *testing.T) {
unixTC := UnixTime(tc)
var marshaled []byte
if temp, err := unixTC.MarshalText(); err == nil {
marshaled = temp
} else {
subT.Error(err)
return
}
var unmarshaled UnixTime
if err := unmarshaled.UnmarshalText(marshaled); err != nil {
subT.Error(err)
return
}
if unmarshaled != unixTC {
t.Logf("\ngot: \t%s\nwant:\t%s", time.Time(unmarshaled).String(), tc.String())
t.Fail()
}
})
}
}

View File

@@ -1,98 +0,0 @@
package autorest
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"fmt"
"net/http"
)
const (
// UndefinedStatusCode is used when HTTP status code is not available for an error.
UndefinedStatusCode = 0
)
// DetailedError encloses a error with details of the package, method, and associated HTTP
// status code (if any).
type DetailedError struct {
Original error
// PackageType is the package type of the object emitting the error. For types, the value
// matches that produced the the '%T' format specifier of the fmt package. For other elements,
// such as functions, it is just the package name (e.g., "autorest").
PackageType string
// Method is the name of the method raising the error.
Method string
// StatusCode is the HTTP Response StatusCode (if non-zero) that led to the error.
StatusCode interface{}
// Message is the error message.
Message string
// Service Error is the response body of failed API in bytes
ServiceError []byte
// Response is the response object that was returned during failure if applicable.
Response *http.Response
}
// NewError creates a new Error conforming object from the passed packageType, method, and
// message. message is treated as a format string to which the optional args apply.
func NewError(packageType string, method string, message string, args ...interface{}) DetailedError {
return NewErrorWithError(nil, packageType, method, nil, message, args...)
}
// NewErrorWithResponse creates a new Error conforming object from the passed
// packageType, method, statusCode of the given resp (UndefinedStatusCode if
// resp is nil), and message. message is treated as a format string to which the
// optional args apply.
func NewErrorWithResponse(packageType string, method string, resp *http.Response, message string, args ...interface{}) DetailedError {
return NewErrorWithError(nil, packageType, method, resp, message, args...)
}
// NewErrorWithError creates a new Error conforming object from the
// passed packageType, method, statusCode of the given resp (UndefinedStatusCode
// if resp is nil), message, and original error. message is treated as a format
// string to which the optional args apply.
func NewErrorWithError(original error, packageType string, method string, resp *http.Response, message string, args ...interface{}) DetailedError {
if v, ok := original.(DetailedError); ok {
return v
}
statusCode := UndefinedStatusCode
if resp != nil {
statusCode = resp.StatusCode
}
return DetailedError{
Original: original,
PackageType: packageType,
Method: method,
StatusCode: statusCode,
Message: fmt.Sprintf(message, args...),
Response: resp,
}
}
// Error returns a formatted containing all available details (i.e., PackageType, Method,
// StatusCode, Message, and original error (if any)).
func (e DetailedError) Error() string {
if e.Original == nil {
return fmt.Sprintf("%s#%s: %s: StatusCode=%d", e.PackageType, e.Method, e.Message, e.StatusCode)
}
return fmt.Sprintf("%s#%s: %s: StatusCode=%d -- Original Error: %v", e.PackageType, e.Method, e.Message, e.StatusCode, e.Original)
}

View File

@@ -1,202 +0,0 @@
package autorest
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"fmt"
"net/http"
"reflect"
"regexp"
"testing"
)
func TestNewErrorWithError_AssignsPackageType(t *testing.T) {
e := NewErrorWithError(fmt.Errorf("original"), "packageType", "method", nil, "message")
if e.PackageType != "packageType" {
t.Fatalf("autorest: Error failed to set package type -- expected %v, received %v", "packageType", e.PackageType)
}
}
func TestNewErrorWithError_AssignsMethod(t *testing.T) {
e := NewErrorWithError(fmt.Errorf("original"), "packageType", "method", nil, "message")
if e.Method != "method" {
t.Fatalf("autorest: Error failed to set method -- expected %v, received %v", "method", e.Method)
}
}
func TestNewErrorWithError_AssignsMessage(t *testing.T) {
e := NewErrorWithError(fmt.Errorf("original"), "packageType", "method", nil, "message")
if e.Message != "message" {
t.Fatalf("autorest: Error failed to set message -- expected %v, received %v", "message", e.Message)
}
}
func TestNewErrorWithError_AssignsUndefinedStatusCodeIfRespNil(t *testing.T) {
e := NewErrorWithError(nil, "packageType", "method", nil, "message")
if e.StatusCode != UndefinedStatusCode {
t.Fatalf("autorest: Error failed to set status code -- expected %v, received %v", UndefinedStatusCode, e.StatusCode)
}
}
func TestNewErrorWithError_AssignsStatusCode(t *testing.T) {
e := NewErrorWithError(fmt.Errorf("original"), "packageType", "method", &http.Response{
StatusCode: http.StatusBadRequest,
Status: http.StatusText(http.StatusBadRequest)}, "message")
if e.StatusCode != http.StatusBadRequest {
t.Fatalf("autorest: Error failed to set status code -- expected %v, received %v", http.StatusBadRequest, e.StatusCode)
}
}
func TestNewErrorWithError_AcceptsArgs(t *testing.T) {
e := NewErrorWithError(fmt.Errorf("original"), "packageType", "method", nil, "message %s", "arg")
if matched, _ := regexp.MatchString(`.*arg.*`, e.Message); !matched {
t.Fatalf("autorest: Error failed to apply message arguments -- expected %v, received %v",
`.*arg.*`, e.Message)
}
}
func TestNewErrorWithError_AssignsError(t *testing.T) {
err := fmt.Errorf("original")
e := NewErrorWithError(err, "packageType", "method", nil, "message")
if e.Original != err {
t.Fatalf("autorest: Error failed to set error -- expected %v, received %v", err, e.Original)
}
}
func TestNewErrorWithResponse_ContainsStatusCode(t *testing.T) {
e := NewErrorWithResponse("packageType", "method", &http.Response{
StatusCode: http.StatusBadRequest,
Status: http.StatusText(http.StatusBadRequest)}, "message")
if e.StatusCode != http.StatusBadRequest {
t.Fatalf("autorest: Error failed to set status code -- expected %v, received %v", http.StatusBadRequest, e.StatusCode)
}
}
func TestNewErrorWithResponse_nilResponse_ReportsUndefinedStatusCode(t *testing.T) {
e := NewErrorWithResponse("packageType", "method", nil, "message")
if e.StatusCode != UndefinedStatusCode {
t.Fatalf("autorest: Error failed to set status code -- expected %v, received %v", UndefinedStatusCode, e.StatusCode)
}
}
func TestNewErrorWithResponse_Forwards(t *testing.T) {
e1 := NewError("packageType", "method", "message %s", "arg")
e2 := NewErrorWithResponse("packageType", "method", nil, "message %s", "arg")
if !reflect.DeepEqual(e1, e2) {
t.Fatal("autorest: NewError did not return an error equivelent to NewErrorWithError")
}
}
func TestNewErrorWithError_Forwards(t *testing.T) {
e1 := NewError("packageType", "method", "message %s", "arg")
e2 := NewErrorWithError(nil, "packageType", "method", nil, "message %s", "arg")
if !reflect.DeepEqual(e1, e2) {
t.Fatal("autorest: NewError did not return an error equivelent to NewErrorWithError")
}
}
func TestNewErrorWithError_DoesNotWrapADetailedError(t *testing.T) {
e1 := NewError("packageType1", "method1", "message1 %s", "arg1")
e2 := NewErrorWithError(e1, "packageType2", "method2", nil, "message2 %s", "arg2")
if !reflect.DeepEqual(e1, e2) {
t.Fatalf("autorest: NewErrorWithError incorrectly wrapped a DetailedError -- expected %v, received %v", e1, e2)
}
}
func TestNewErrorWithError_WrapsAnError(t *testing.T) {
e1 := fmt.Errorf("Inner Error")
var e2 interface{} = NewErrorWithError(e1, "packageType", "method", nil, "message")
if _, ok := e2.(DetailedError); !ok {
t.Fatalf("autorest: NewErrorWithError failed to wrap a standard error -- received %T", e2)
}
}
func TestDetailedError(t *testing.T) {
err := fmt.Errorf("original")
e := NewErrorWithError(err, "packageType", "method", nil, "message")
if matched, _ := regexp.MatchString(`.*original.*`, e.Error()); !matched {
t.Fatalf("autorest: Error#Error failed to return original error message -- expected %v, received %v",
`.*original.*`, e.Error())
}
}
func TestDetailedErrorConstainsPackageType(t *testing.T) {
e := NewErrorWithError(fmt.Errorf("original"), "packageType", "method", nil, "message")
if matched, _ := regexp.MatchString(`.*packageType.*`, e.Error()); !matched {
t.Fatalf("autorest: Error#String failed to include PackageType -- expected %v, received %v",
`.*packageType.*`, e.Error())
}
}
func TestDetailedErrorConstainsMethod(t *testing.T) {
e := NewErrorWithError(fmt.Errorf("original"), "packageType", "method", nil, "message")
if matched, _ := regexp.MatchString(`.*method.*`, e.Error()); !matched {
t.Fatalf("autorest: Error#String failed to include Method -- expected %v, received %v",
`.*method.*`, e.Error())
}
}
func TestDetailedErrorConstainsMessage(t *testing.T) {
e := NewErrorWithError(fmt.Errorf("original"), "packageType", "method", nil, "message")
if matched, _ := regexp.MatchString(`.*message.*`, e.Error()); !matched {
t.Fatalf("autorest: Error#String failed to include Message -- expected %v, received %v",
`.*message.*`, e.Error())
}
}
func TestDetailedErrorConstainsStatusCode(t *testing.T) {
e := NewErrorWithError(fmt.Errorf("original"), "packageType", "method", &http.Response{
StatusCode: http.StatusBadRequest,
Status: http.StatusText(http.StatusBadRequest)}, "message")
if matched, _ := regexp.MatchString(`.*400.*`, e.Error()); !matched {
t.Fatalf("autorest: Error#String failed to include Status Code -- expected %v, received %v",
`.*400.*`, e.Error())
}
}
func TestDetailedErrorConstainsOriginal(t *testing.T) {
e := NewErrorWithError(fmt.Errorf("original"), "packageType", "method", nil, "message")
if matched, _ := regexp.MatchString(`.*original.*`, e.Error()); !matched {
t.Fatalf("autorest: Error#String failed to include Original error -- expected %v, received %v",
`.*original.*`, e.Error())
}
}
func TestDetailedErrorSkipsOriginal(t *testing.T) {
e := NewError("packageType", "method", "message")
if matched, _ := regexp.MatchString(`.*Original.*`, e.Error()); matched {
t.Fatalf("autorest: Error#String included missing Original error -- unexpected %v, received %v",
`.*Original.*`, e.Error())
}
}

View File

@@ -1,151 +0,0 @@
package mocks
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"fmt"
"net/http"
"time"
)
const (
// TestAuthorizationHeader is a faux HTTP Authorization header value
TestAuthorizationHeader = "BEARER SECRETTOKEN"
// TestBadURL is a malformed URL
TestBadURL = " "
// TestDelay is the Retry-After delay used in tests.
TestDelay = 0 * time.Second
// TestHeader is the header used in tests.
TestHeader = "x-test-header"
// TestURL is the URL used in tests.
TestURL = "https://microsoft.com/a/b/c/"
// TestAzureAsyncURL is a URL used in Azure asynchronous tests
TestAzureAsyncURL = "https://microsoft.com/a/b/c/async"
// TestLocationURL is a URL used in Azure asynchronous tests
TestLocationURL = "https://microsoft.com/a/b/c/location"
)
const (
headerLocation = "Location"
headerRetryAfter = "Retry-After"
)
// NewRequest instantiates a new request.
func NewRequest() *http.Request {
return NewRequestWithContent("")
}
// NewRequestWithContent instantiates a new request using the passed string for the body content.
func NewRequestWithContent(c string) *http.Request {
r, _ := http.NewRequest("GET", "https://microsoft.com/a/b/c/", NewBody(c))
return r
}
// NewRequestWithCloseBody instantiates a new request.
func NewRequestWithCloseBody() *http.Request {
return NewRequestWithCloseBodyContent("request body")
}
// NewRequestWithCloseBodyContent instantiates a new request using the passed string for the body content.
func NewRequestWithCloseBodyContent(c string) *http.Request {
r, _ := http.NewRequest("GET", "https://microsoft.com/a/b/c/", NewBodyClose(c))
return r
}
// NewRequestForURL instantiates a new request using the passed URL.
func NewRequestForURL(u string) *http.Request {
r, err := http.NewRequest("GET", u, NewBody(""))
if err != nil {
panic(fmt.Sprintf("mocks: ERROR (%v) parsing testing URL %s", err, u))
}
return r
}
// NewResponse instantiates a new response.
func NewResponse() *http.Response {
return NewResponseWithContent("")
}
// NewResponseWithContent instantiates a new response with the passed string as the body content.
func NewResponseWithContent(c string) *http.Response {
return &http.Response{
Status: "200 OK",
StatusCode: 200,
Proto: "HTTP/1.0",
ProtoMajor: 1,
ProtoMinor: 0,
Body: NewBody(c),
Request: NewRequest(),
}
}
// NewResponseWithStatus instantiates a new response using the passed string and integer as the
// status and status code.
func NewResponseWithStatus(s string, c int) *http.Response {
resp := NewResponse()
resp.Status = s
resp.StatusCode = c
return resp
}
// NewResponseWithBodyAndStatus instantiates a new response using the specified mock body,
// status and status code
func NewResponseWithBodyAndStatus(body *Body, c int, s string) *http.Response {
resp := NewResponse()
resp.Body = body
resp.Status = s
resp.StatusCode = c
return resp
}
// SetResponseHeader adds a header to the passed response.
func SetResponseHeader(resp *http.Response, h string, v string) {
if resp.Header == nil {
resp.Header = make(http.Header)
}
resp.Header.Set(h, v)
}
// SetResponseHeaderValues adds a header containing all the passed string values.
func SetResponseHeaderValues(resp *http.Response, h string, values []string) {
if resp.Header == nil {
resp.Header = make(http.Header)
}
for _, v := range values {
resp.Header.Add(h, v)
}
}
// SetAcceptedHeaders adds the headers usually associated with a 202 Accepted response.
func SetAcceptedHeaders(resp *http.Response) {
SetLocationHeader(resp, TestURL)
SetRetryHeader(resp, TestDelay)
}
// SetLocationHeader adds the Location header.
func SetLocationHeader(resp *http.Response, location string) {
SetResponseHeader(resp, http.CanonicalHeaderKey(headerLocation), location)
}
// SetRetryHeader adds the Retry-After header.
func SetRetryHeader(resp *http.Response, delay time.Duration) {
SetResponseHeader(resp, http.CanonicalHeaderKey(headerRetryAfter), fmt.Sprintf("%v", delay.Seconds()))
}

View File

@@ -1,219 +0,0 @@
/*
Package mocks provides mocks and helpers used in testing.
*/
package mocks
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"fmt"
"io"
"net/http"
"time"
)
// Body implements acceptable body over a string.
type Body struct {
s string
b []byte
isOpen bool
closeAttempts int
}
// NewBody creates a new instance of Body.
func NewBody(s string) *Body {
return (&Body{s: s}).reset()
}
// NewBodyClose creates a new instance of Body.
func NewBodyClose(s string) *Body {
return &Body{s: s}
}
// Read reads into the passed byte slice and returns the bytes read.
func (body *Body) Read(b []byte) (n int, err error) {
if !body.IsOpen() {
return 0, fmt.Errorf("ERROR: Body has been closed")
}
if len(body.b) == 0 {
return 0, io.EOF
}
n = copy(b, body.b)
body.b = body.b[n:]
return n, nil
}
// Close closes the body.
func (body *Body) Close() error {
if body.isOpen {
body.isOpen = false
body.closeAttempts++
}
return nil
}
// CloseAttempts returns the number of times Close was called.
func (body *Body) CloseAttempts() int {
return body.closeAttempts
}
// IsOpen returns true if the Body has not been closed, false otherwise.
func (body *Body) IsOpen() bool {
return body.isOpen
}
func (body *Body) reset() *Body {
body.isOpen = true
body.b = []byte(body.s)
return body
}
type response struct {
r *http.Response
e error
d time.Duration
}
// Sender implements a simple null sender.
type Sender struct {
attempts int
responses []response
numResponses int
repeatResponse []int
err error
repeatError int
emitErrorAfter int
}
// NewSender creates a new instance of Sender.
func NewSender() *Sender {
return &Sender{}
}
// Do accepts the passed request and, based on settings, emits a response and possible error.
func (c *Sender) Do(r *http.Request) (resp *http.Response, err error) {
c.attempts++
if len(c.responses) > 0 {
resp = c.responses[0].r
if resp != nil {
if b, ok := resp.Body.(*Body); ok {
b.reset()
}
} else {
err = c.responses[0].e
}
time.Sleep(c.responses[0].d)
c.repeatResponse[0]--
if c.repeatResponse[0] == 0 {
c.responses = c.responses[1:]
c.repeatResponse = c.repeatResponse[1:]
}
} else {
resp = NewResponse()
}
if resp != nil {
resp.Request = r
}
if c.emitErrorAfter > 0 {
c.emitErrorAfter--
} else if c.err != nil {
err = c.err
c.repeatError--
if c.repeatError == 0 {
c.err = nil
}
}
return
}
// AppendResponse adds the passed http.Response to the response stack.
func (c *Sender) AppendResponse(resp *http.Response) {
c.AppendAndRepeatResponse(resp, 1)
}
// AppendResponseWithDelay adds the passed http.Response to the response stack with the specified delay.
func (c *Sender) AppendResponseWithDelay(resp *http.Response, delay time.Duration) {
c.AppendAndRepeatResponseWithDelay(resp, delay, 1)
}
// AppendAndRepeatResponse adds the passed http.Response to the response stack along with a
// repeat count. A negative repeat count will return the response for all remaining calls to Do.
func (c *Sender) AppendAndRepeatResponse(resp *http.Response, repeat int) {
c.appendAndRepeat(response{r: resp}, repeat)
}
// AppendAndRepeatResponseWithDelay adds the passed http.Response to the response stack with the specified
// delay along with a repeat count. A negative repeat count will return the response for all remaining calls to Do.
func (c *Sender) AppendAndRepeatResponseWithDelay(resp *http.Response, delay time.Duration, repeat int) {
c.appendAndRepeat(response{r: resp, d: delay}, repeat)
}
// AppendError adds the passed error to the response stack.
func (c *Sender) AppendError(err error) {
c.AppendAndRepeatError(err, 1)
}
// AppendAndRepeatError adds the passed error to the response stack along with a repeat
// count. A negative repeat count will return the response for all remaining calls to Do.
func (c *Sender) AppendAndRepeatError(err error, repeat int) {
c.appendAndRepeat(response{e: err}, repeat)
}
func (c *Sender) appendAndRepeat(resp response, repeat int) {
if c.responses == nil {
c.responses = []response{resp}
c.repeatResponse = []int{repeat}
} else {
c.responses = append(c.responses, resp)
c.repeatResponse = append(c.repeatResponse, repeat)
}
c.numResponses++
}
// Attempts returns the number of times Do was called.
func (c *Sender) Attempts() int {
return c.attempts
}
// SetError sets the error Do should return.
func (c *Sender) SetError(err error) {
c.SetAndRepeatError(err, 1)
}
// SetAndRepeatError sets the error Do should return and how many calls to Do will return the error.
// A negative repeat value will return the error for all remaining calls to Do.
func (c *Sender) SetAndRepeatError(err error, repeat int) {
c.err = err
c.repeatError = repeat
}
// SetEmitErrorAfter sets the number of attempts to be made before errors are emitted.
func (c *Sender) SetEmitErrorAfter(ea int) {
c.emitErrorAfter = ea
}
// NumResponses returns the number of responses that have been added to the sender.
func (c *Sender) NumResponses() int {
return c.numResponses
}
// T is a simple testing struct.
type T struct {
Name string `json:"name" xml:"Name"`
Age int `json:"age" xml:"Age"`
}

View File

@@ -1,442 +0,0 @@
package autorest
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"bytes"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"mime/multipart"
"net/http"
"net/url"
"strings"
)
const (
mimeTypeJSON = "application/json"
mimeTypeFormPost = "application/x-www-form-urlencoded"
headerAuthorization = "Authorization"
headerContentType = "Content-Type"
headerUserAgent = "User-Agent"
)
// Preparer is the interface that wraps the Prepare method.
//
// Prepare accepts and possibly modifies an http.Request (e.g., adding Headers). Implementations
// must ensure to not share or hold per-invocation state since Preparers may be shared and re-used.
type Preparer interface {
Prepare(*http.Request) (*http.Request, error)
}
// PreparerFunc is a method that implements the Preparer interface.
type PreparerFunc func(*http.Request) (*http.Request, error)
// Prepare implements the Preparer interface on PreparerFunc.
func (pf PreparerFunc) Prepare(r *http.Request) (*http.Request, error) {
return pf(r)
}
// PrepareDecorator takes and possibly decorates, by wrapping, a Preparer. Decorators may affect the
// http.Request and pass it along or, first, pass the http.Request along then affect the result.
type PrepareDecorator func(Preparer) Preparer
// CreatePreparer creates, decorates, and returns a Preparer.
// Without decorators, the returned Preparer returns the passed http.Request unmodified.
// Preparers are safe to share and re-use.
func CreatePreparer(decorators ...PrepareDecorator) Preparer {
return DecoratePreparer(
Preparer(PreparerFunc(func(r *http.Request) (*http.Request, error) { return r, nil })),
decorators...)
}
// DecoratePreparer accepts a Preparer and a, possibly empty, set of PrepareDecorators, which it
// applies to the Preparer. Decorators are applied in the order received, but their affect upon the
// request depends on whether they are a pre-decorator (change the http.Request and then pass it
// along) or a post-decorator (pass the http.Request along and alter it on return).
func DecoratePreparer(p Preparer, decorators ...PrepareDecorator) Preparer {
for _, decorate := range decorators {
p = decorate(p)
}
return p
}
// Prepare accepts an http.Request and a, possibly empty, set of PrepareDecorators.
// It creates a Preparer from the decorators which it then applies to the passed http.Request.
func Prepare(r *http.Request, decorators ...PrepareDecorator) (*http.Request, error) {
if r == nil {
return nil, NewError("autorest", "Prepare", "Invoked without an http.Request")
}
return CreatePreparer(decorators...).Prepare(r)
}
// WithNothing returns a "do nothing" PrepareDecorator that makes no changes to the passed
// http.Request.
func WithNothing() PrepareDecorator {
return func(p Preparer) Preparer {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
return p.Prepare(r)
})
}
}
// WithHeader returns a PrepareDecorator that sets the specified HTTP header of the http.Request to
// the passed value. It canonicalizes the passed header name (via http.CanonicalHeaderKey) before
// adding the header.
func WithHeader(header string, value string) PrepareDecorator {
return func(p Preparer) Preparer {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
r, err := p.Prepare(r)
if err == nil {
if r.Header == nil {
r.Header = make(http.Header)
}
r.Header.Set(http.CanonicalHeaderKey(header), value)
}
return r, err
})
}
}
// WithBearerAuthorization returns a PrepareDecorator that adds an HTTP Authorization header whose
// value is "Bearer " followed by the supplied token.
func WithBearerAuthorization(token string) PrepareDecorator {
return WithHeader(headerAuthorization, fmt.Sprintf("Bearer %s", token))
}
// AsContentType returns a PrepareDecorator that adds an HTTP Content-Type header whose value
// is the passed contentType.
func AsContentType(contentType string) PrepareDecorator {
return WithHeader(headerContentType, contentType)
}
// WithUserAgent returns a PrepareDecorator that adds an HTTP User-Agent header whose value is the
// passed string.
func WithUserAgent(ua string) PrepareDecorator {
return WithHeader(headerUserAgent, ua)
}
// AsFormURLEncoded returns a PrepareDecorator that adds an HTTP Content-Type header whose value is
// "application/x-www-form-urlencoded".
func AsFormURLEncoded() PrepareDecorator {
return AsContentType(mimeTypeFormPost)
}
// AsJSON returns a PrepareDecorator that adds an HTTP Content-Type header whose value is
// "application/json".
func AsJSON() PrepareDecorator {
return AsContentType(mimeTypeJSON)
}
// WithMethod returns a PrepareDecorator that sets the HTTP method of the passed request. The
// decorator does not validate that the passed method string is a known HTTP method.
func WithMethod(method string) PrepareDecorator {
return func(p Preparer) Preparer {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
r.Method = method
return p.Prepare(r)
})
}
}
// AsDelete returns a PrepareDecorator that sets the HTTP method to DELETE.
func AsDelete() PrepareDecorator { return WithMethod("DELETE") }
// AsGet returns a PrepareDecorator that sets the HTTP method to GET.
func AsGet() PrepareDecorator { return WithMethod("GET") }
// AsHead returns a PrepareDecorator that sets the HTTP method to HEAD.
func AsHead() PrepareDecorator { return WithMethod("HEAD") }
// AsOptions returns a PrepareDecorator that sets the HTTP method to OPTIONS.
func AsOptions() PrepareDecorator { return WithMethod("OPTIONS") }
// AsPatch returns a PrepareDecorator that sets the HTTP method to PATCH.
func AsPatch() PrepareDecorator { return WithMethod("PATCH") }
// AsPost returns a PrepareDecorator that sets the HTTP method to POST.
func AsPost() PrepareDecorator { return WithMethod("POST") }
// AsPut returns a PrepareDecorator that sets the HTTP method to PUT.
func AsPut() PrepareDecorator { return WithMethod("PUT") }
// WithBaseURL returns a PrepareDecorator that populates the http.Request with a url.URL constructed
// from the supplied baseUrl.
func WithBaseURL(baseURL string) PrepareDecorator {
return func(p Preparer) Preparer {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
r, err := p.Prepare(r)
if err == nil {
var u *url.URL
if u, err = url.Parse(baseURL); err != nil {
return r, err
}
if u.Scheme == "" {
err = fmt.Errorf("autorest: No scheme detected in URL %s", baseURL)
}
if err == nil {
r.URL = u
}
}
return r, err
})
}
}
// WithCustomBaseURL returns a PrepareDecorator that replaces brace-enclosed keys within the
// request base URL (i.e., http.Request.URL) with the corresponding values from the passed map.
func WithCustomBaseURL(baseURL string, urlParameters map[string]interface{}) PrepareDecorator {
parameters := ensureValueStrings(urlParameters)
for key, value := range parameters {
baseURL = strings.Replace(baseURL, "{"+key+"}", value, -1)
}
return WithBaseURL(baseURL)
}
// WithFormData returns a PrepareDecoratore that "URL encodes" (e.g., bar=baz&foo=quux) into the
// http.Request body.
func WithFormData(v url.Values) PrepareDecorator {
return func(p Preparer) Preparer {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
r, err := p.Prepare(r)
if err == nil {
s := v.Encode()
r.ContentLength = int64(len(s))
r.Body = ioutil.NopCloser(strings.NewReader(s))
}
return r, err
})
}
}
// WithMultiPartFormData returns a PrepareDecoratore that "URL encodes" (e.g., bar=baz&foo=quux) form parameters
// into the http.Request body.
func WithMultiPartFormData(formDataParameters map[string]interface{}) PrepareDecorator {
return func(p Preparer) Preparer {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
r, err := p.Prepare(r)
if err == nil {
var body bytes.Buffer
writer := multipart.NewWriter(&body)
for key, value := range formDataParameters {
if rc, ok := value.(io.ReadCloser); ok {
var fd io.Writer
if fd, err = writer.CreateFormFile(key, key); err != nil {
return r, err
}
if _, err = io.Copy(fd, rc); err != nil {
return r, err
}
} else {
if err = writer.WriteField(key, ensureValueString(value)); err != nil {
return r, err
}
}
}
if err = writer.Close(); err != nil {
return r, err
}
if r.Header == nil {
r.Header = make(http.Header)
}
r.Header.Set(http.CanonicalHeaderKey(headerContentType), writer.FormDataContentType())
r.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
r.ContentLength = int64(body.Len())
return r, err
}
return r, err
})
}
}
// WithFile returns a PrepareDecorator that sends file in request body.
func WithFile(f io.ReadCloser) PrepareDecorator {
return func(p Preparer) Preparer {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
r, err := p.Prepare(r)
if err == nil {
b, err := ioutil.ReadAll(f)
if err != nil {
return r, err
}
r.Body = ioutil.NopCloser(bytes.NewReader(b))
r.ContentLength = int64(len(b))
}
return r, err
})
}
}
// WithBool returns a PrepareDecorator that encodes the passed bool into the body of the request
// and sets the Content-Length header.
func WithBool(v bool) PrepareDecorator {
return WithString(fmt.Sprintf("%v", v))
}
// WithFloat32 returns a PrepareDecorator that encodes the passed float32 into the body of the
// request and sets the Content-Length header.
func WithFloat32(v float32) PrepareDecorator {
return WithString(fmt.Sprintf("%v", v))
}
// WithFloat64 returns a PrepareDecorator that encodes the passed float64 into the body of the
// request and sets the Content-Length header.
func WithFloat64(v float64) PrepareDecorator {
return WithString(fmt.Sprintf("%v", v))
}
// WithInt32 returns a PrepareDecorator that encodes the passed int32 into the body of the request
// and sets the Content-Length header.
func WithInt32(v int32) PrepareDecorator {
return WithString(fmt.Sprintf("%v", v))
}
// WithInt64 returns a PrepareDecorator that encodes the passed int64 into the body of the request
// and sets the Content-Length header.
func WithInt64(v int64) PrepareDecorator {
return WithString(fmt.Sprintf("%v", v))
}
// WithString returns a PrepareDecorator that encodes the passed string into the body of the request
// and sets the Content-Length header.
func WithString(v string) PrepareDecorator {
return func(p Preparer) Preparer {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
r, err := p.Prepare(r)
if err == nil {
r.ContentLength = int64(len(v))
r.Body = ioutil.NopCloser(strings.NewReader(v))
}
return r, err
})
}
}
// WithJSON returns a PrepareDecorator that encodes the data passed as JSON into the body of the
// request and sets the Content-Length header.
func WithJSON(v interface{}) PrepareDecorator {
return func(p Preparer) Preparer {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
r, err := p.Prepare(r)
if err == nil {
b, err := json.Marshal(v)
if err == nil {
r.ContentLength = int64(len(b))
r.Body = ioutil.NopCloser(bytes.NewReader(b))
}
}
return r, err
})
}
}
// WithPath returns a PrepareDecorator that adds the supplied path to the request URL. If the path
// is absolute (that is, it begins with a "/"), it replaces the existing path.
func WithPath(path string) PrepareDecorator {
return func(p Preparer) Preparer {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
r, err := p.Prepare(r)
if err == nil {
if r.URL == nil {
return r, NewError("autorest", "WithPath", "Invoked with a nil URL")
}
if r.URL, err = parseURL(r.URL, path); err != nil {
return r, err
}
}
return r, err
})
}
}
// WithEscapedPathParameters returns a PrepareDecorator that replaces brace-enclosed keys within the
// request path (i.e., http.Request.URL.Path) with the corresponding values from the passed map. The
// values will be escaped (aka URL encoded) before insertion into the path.
func WithEscapedPathParameters(path string, pathParameters map[string]interface{}) PrepareDecorator {
parameters := escapeValueStrings(ensureValueStrings(pathParameters))
return func(p Preparer) Preparer {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
r, err := p.Prepare(r)
if err == nil {
if r.URL == nil {
return r, NewError("autorest", "WithEscapedPathParameters", "Invoked with a nil URL")
}
for key, value := range parameters {
path = strings.Replace(path, "{"+key+"}", value, -1)
}
if r.URL, err = parseURL(r.URL, path); err != nil {
return r, err
}
}
return r, err
})
}
}
// WithPathParameters returns a PrepareDecorator that replaces brace-enclosed keys within the
// request path (i.e., http.Request.URL.Path) with the corresponding values from the passed map.
func WithPathParameters(path string, pathParameters map[string]interface{}) PrepareDecorator {
parameters := ensureValueStrings(pathParameters)
return func(p Preparer) Preparer {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
r, err := p.Prepare(r)
if err == nil {
if r.URL == nil {
return r, NewError("autorest", "WithPathParameters", "Invoked with a nil URL")
}
for key, value := range parameters {
path = strings.Replace(path, "{"+key+"}", value, -1)
}
if r.URL, err = parseURL(r.URL, path); err != nil {
return r, err
}
}
return r, err
})
}
}
func parseURL(u *url.URL, path string) (*url.URL, error) {
p := strings.TrimRight(u.String(), "/")
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
return url.Parse(p + path)
}
// WithQueryParameters returns a PrepareDecorators that encodes and applies the query parameters
// given in the supplied map (i.e., key=value).
func WithQueryParameters(queryParameters map[string]interface{}) PrepareDecorator {
parameters := ensureValueStrings(queryParameters)
return func(p Preparer) Preparer {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
r, err := p.Prepare(r)
if err == nil {
if r.URL == nil {
return r, NewError("autorest", "WithQueryParameters", "Invoked with a nil URL")
}
v := r.URL.Query()
for key, value := range parameters {
v.Add(key, value)
}
r.URL.RawQuery = createQuery(v)
}
return r, err
})
}
}

View File

@@ -1,766 +0,0 @@
package autorest
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"fmt"
"io/ioutil"
"net/http"
"net/url"
"reflect"
"strconv"
"strings"
"testing"
"github.com/Azure/go-autorest/autorest/mocks"
)
// PrepareDecorators wrap and invoke a Preparer. Most often, the decorator invokes the passed
// Preparer and decorates the response.
func ExamplePrepareDecorator() {
path := "a/b/c/"
pd := func() PrepareDecorator {
return func(p Preparer) Preparer {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
r, err := p.Prepare(r)
if err == nil {
if r.URL == nil {
return r, fmt.Errorf("ERROR: URL is not set")
}
r.URL.Path += path
}
return r, err
})
}
}
r, _ := Prepare(&http.Request{},
WithBaseURL("https://microsoft.com/"),
pd())
fmt.Printf("Path is %s\n", r.URL)
// Output: Path is https://microsoft.com/a/b/c/
}
// PrepareDecorators may also modify and then invoke the Preparer.
func ExamplePrepareDecorator_pre() {
pd := func() PrepareDecorator {
return func(p Preparer) Preparer {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
r.Header.Add(http.CanonicalHeaderKey("ContentType"), "application/json")
return p.Prepare(r)
})
}
}
r, _ := Prepare(&http.Request{Header: http.Header{}},
pd())
fmt.Printf("ContentType is %s\n", r.Header.Get("ContentType"))
// Output: ContentType is application/json
}
// Create a sequence of three Preparers that build up the URL path.
func ExampleCreatePreparer() {
p := CreatePreparer(
WithBaseURL("https://microsoft.com/"),
WithPath("a"),
WithPath("b"),
WithPath("c"))
r, err := p.Prepare(&http.Request{})
if err != nil {
fmt.Printf("ERROR: %v\n", err)
} else {
fmt.Println(r.URL)
}
// Output: https://microsoft.com/a/b/c
}
// Create and apply separate Preparers
func ExampleCreatePreparer_multiple() {
params := map[string]interface{}{
"param1": "a",
"param2": "c",
}
p1 := CreatePreparer(WithBaseURL("https://microsoft.com/"))
p2 := CreatePreparer(WithPathParameters("/{param1}/b/{param2}/", params))
r, err := p1.Prepare(&http.Request{})
if err != nil {
fmt.Printf("ERROR: %v\n", err)
}
r, err = p2.Prepare(r)
if err != nil {
fmt.Printf("ERROR: %v\n", err)
} else {
fmt.Println(r.URL)
}
// Output: https://microsoft.com/a/b/c/
}
// Create and chain separate Preparers
func ExampleCreatePreparer_chain() {
params := map[string]interface{}{
"param1": "a",
"param2": "c",
}
p := CreatePreparer(WithBaseURL("https://microsoft.com/"))
p = DecoratePreparer(p, WithPathParameters("/{param1}/b/{param2}/", params))
r, err := p.Prepare(&http.Request{})
if err != nil {
fmt.Printf("ERROR: %v\n", err)
} else {
fmt.Println(r.URL)
}
// Output: https://microsoft.com/a/b/c/
}
// Create and prepare an http.Request in one call
func ExamplePrepare() {
r, err := Prepare(&http.Request{},
AsGet(),
WithBaseURL("https://microsoft.com/"),
WithPath("a/b/c/"))
if err != nil {
fmt.Printf("ERROR: %v\n", err)
} else {
fmt.Printf("%s %s", r.Method, r.URL)
}
// Output: GET https://microsoft.com/a/b/c/
}
// Create a request for a supplied base URL and path
func ExampleWithBaseURL() {
r, err := Prepare(&http.Request{},
WithBaseURL("https://microsoft.com/a/b/c/"))
if err != nil {
fmt.Printf("ERROR: %v\n", err)
} else {
fmt.Println(r.URL)
}
// Output: https://microsoft.com/a/b/c/
}
func ExampleWithBaseURL_second() {
_, err := Prepare(&http.Request{}, WithBaseURL(":"))
fmt.Println(err)
// Output: parse :: missing protocol scheme
}
func ExampleWithCustomBaseURL() {
r, err := Prepare(&http.Request{},
WithCustomBaseURL("https://{account}.{service}.core.windows.net/",
map[string]interface{}{
"account": "myaccount",
"service": "blob",
}))
if err != nil {
fmt.Printf("ERROR: %v\n", err)
} else {
fmt.Println(r.URL)
}
// Output: https://myaccount.blob.core.windows.net/
}
func ExampleWithCustomBaseURL_second() {
_, err := Prepare(&http.Request{},
WithCustomBaseURL(":", map[string]interface{}{}))
fmt.Println(err)
// Output: parse :: missing protocol scheme
}
// Create a request with a custom HTTP header
func ExampleWithHeader() {
r, err := Prepare(&http.Request{},
WithBaseURL("https://microsoft.com/a/b/c/"),
WithHeader("x-foo", "bar"))
if err != nil {
fmt.Printf("ERROR: %v\n", err)
} else {
fmt.Printf("Header %s=%s\n", "x-foo", r.Header.Get("x-foo"))
}
// Output: Header x-foo=bar
}
// Create a request whose Body is the JSON encoding of a structure
func ExampleWithFormData() {
v := url.Values{}
v.Add("name", "Rob Pike")
v.Add("age", "42")
r, err := Prepare(&http.Request{},
WithFormData(v))
if err != nil {
fmt.Printf("ERROR: %v\n", err)
}
b, err := ioutil.ReadAll(r.Body)
if err != nil {
fmt.Printf("ERROR: %v\n", err)
} else {
fmt.Printf("Request Body contains %s\n", string(b))
}
// Output: Request Body contains age=42&name=Rob+Pike
}
// Create a request whose Body is the JSON encoding of a structure
func ExampleWithJSON() {
t := mocks.T{Name: "Rob Pike", Age: 42}
r, err := Prepare(&http.Request{},
WithJSON(&t))
if err != nil {
fmt.Printf("ERROR: %v\n", err)
}
b, err := ioutil.ReadAll(r.Body)
if err != nil {
fmt.Printf("ERROR: %v\n", err)
} else {
fmt.Printf("Request Body contains %s\n", string(b))
}
// Output: Request Body contains {"name":"Rob Pike","age":42}
}
// Create a request from a path with escaped parameters
func ExampleWithEscapedPathParameters() {
params := map[string]interface{}{
"param1": "a b c",
"param2": "d e f",
}
r, err := Prepare(&http.Request{},
WithBaseURL("https://microsoft.com/"),
WithEscapedPathParameters("/{param1}/b/{param2}/", params))
if err != nil {
fmt.Printf("ERROR: %v\n", err)
} else {
fmt.Println(r.URL)
}
// Output: https://microsoft.com/a+b+c/b/d+e+f/
}
// Create a request from a path with parameters
func ExampleWithPathParameters() {
params := map[string]interface{}{
"param1": "a",
"param2": "c",
}
r, err := Prepare(&http.Request{},
WithBaseURL("https://microsoft.com/"),
WithPathParameters("/{param1}/b/{param2}/", params))
if err != nil {
fmt.Printf("ERROR: %v\n", err)
} else {
fmt.Println(r.URL)
}
// Output: https://microsoft.com/a/b/c/
}
// Create a request with query parameters
func ExampleWithQueryParameters() {
params := map[string]interface{}{
"q1": "value1",
"q2": "value2",
}
r, err := Prepare(&http.Request{},
WithBaseURL("https://microsoft.com/"),
WithPath("/a/b/c/"),
WithQueryParameters(params))
if err != nil {
fmt.Printf("ERROR: %v\n", err)
} else {
fmt.Println(r.URL)
}
// Output: https://microsoft.com/a/b/c/?q1=value1&q2=value2
}
func TestWithCustomBaseURL(t *testing.T) {
r, err := Prepare(&http.Request{}, WithCustomBaseURL("https://{account}.{service}.core.windows.net/",
map[string]interface{}{
"account": "myaccount",
"service": "blob",
}))
if err != nil {
t.Fatalf("autorest: WithCustomBaseURL should not fail")
}
if r.URL.String() != "https://myaccount.blob.core.windows.net/" {
t.Fatalf("autorest: WithCustomBaseURL expected https://myaccount.blob.core.windows.net/, got %s", r.URL)
}
}
func TestWithCustomBaseURLwithInvalidURL(t *testing.T) {
_, err := Prepare(&http.Request{}, WithCustomBaseURL("hello/{account}.{service}.core.windows.net/",
map[string]interface{}{
"account": "myaccount",
"service": "blob",
}))
if err == nil {
t.Fatalf("autorest: WithCustomBaseURL should fail fo URL parse error")
}
}
func TestWithPathWithInvalidPath(t *testing.T) {
p := "path%2*end"
if _, err := Prepare(&http.Request{}, WithBaseURL("https://microsoft.com/"), WithPath(p)); err == nil {
t.Fatalf("autorest: WithPath should fail for invalid URL escape error for path '%v' ", p)
}
}
func TestWithPathParametersWithInvalidPath(t *testing.T) {
p := "path%2*end"
m := map[string]interface{}{
"path1": p,
}
if _, err := Prepare(&http.Request{}, WithBaseURL("https://microsoft.com/"), WithPathParameters("/{path1}/", m)); err == nil {
t.Fatalf("autorest: WithPath should fail for invalid URL escape for path '%v' ", p)
}
}
func TestCreatePreparerDoesNotModify(t *testing.T) {
r1 := &http.Request{}
p := CreatePreparer()
r2, err := p.Prepare(r1)
if err != nil {
t.Fatalf("autorest: CreatePreparer failed (%v)", err)
}
if !reflect.DeepEqual(r1, r2) {
t.Fatalf("autorest: CreatePreparer without decorators modified the request")
}
}
func TestCreatePreparerRunsDecoratorsInOrder(t *testing.T) {
p := CreatePreparer(WithBaseURL("https://microsoft.com/"), WithPath("1"), WithPath("2"), WithPath("3"))
r, err := p.Prepare(&http.Request{})
if err != nil {
t.Fatalf("autorest: CreatePreparer failed (%v)", err)
}
if r.URL.String() != "https:/1/2/3" && r.URL.Host != "microsoft.com" {
t.Fatalf("autorest: CreatePreparer failed to run decorators in order")
}
}
func TestAsContentType(t *testing.T) {
r, err := Prepare(mocks.NewRequest(), AsContentType("application/text"))
if err != nil {
fmt.Printf("ERROR: %v", err)
}
if r.Header.Get(headerContentType) != "application/text" {
t.Fatalf("autorest: AsContentType failed to add header (%s=%s)", headerContentType, r.Header.Get(headerContentType))
}
}
func TestAsFormURLEncoded(t *testing.T) {
r, err := Prepare(mocks.NewRequest(), AsFormURLEncoded())
if err != nil {
fmt.Printf("ERROR: %v", err)
}
if r.Header.Get(headerContentType) != mimeTypeFormPost {
t.Fatalf("autorest: AsFormURLEncoded failed to add header (%s=%s)", headerContentType, r.Header.Get(headerContentType))
}
}
func TestAsJSON(t *testing.T) {
r, err := Prepare(mocks.NewRequest(), AsJSON())
if err != nil {
fmt.Printf("ERROR: %v", err)
}
if r.Header.Get(headerContentType) != mimeTypeJSON {
t.Fatalf("autorest: AsJSON failed to add header (%s=%s)", headerContentType, r.Header.Get(headerContentType))
}
}
func TestWithNothing(t *testing.T) {
r1 := mocks.NewRequest()
r2, err := Prepare(r1, WithNothing())
if err != nil {
t.Fatalf("autorest: WithNothing returned an unexpected error (%v)", err)
}
if !reflect.DeepEqual(r1, r2) {
t.Fatal("azure: WithNothing modified the passed HTTP Request")
}
}
func TestWithBearerAuthorization(t *testing.T) {
r, err := Prepare(mocks.NewRequest(), WithBearerAuthorization("SOME-TOKEN"))
if err != nil {
fmt.Printf("ERROR: %v", err)
}
if r.Header.Get(headerAuthorization) != "Bearer SOME-TOKEN" {
t.Fatalf("autorest: WithBearerAuthorization failed to add header (%s=%s)", headerAuthorization, r.Header.Get(headerAuthorization))
}
}
func TestWithUserAgent(t *testing.T) {
ua := "User Agent Go"
r, err := Prepare(mocks.NewRequest(), WithUserAgent(ua))
if err != nil {
fmt.Printf("ERROR: %v", err)
}
if r.UserAgent() != ua || r.Header.Get(headerUserAgent) != ua {
t.Fatalf("autorest: WithUserAgent failed to add header (%s=%s)", headerUserAgent, r.Header.Get(headerUserAgent))
}
}
func TestWithMethod(t *testing.T) {
r, _ := Prepare(mocks.NewRequest(), WithMethod("HEAD"))
if r.Method != "HEAD" {
t.Fatal("autorest: WithMethod failed to set HTTP method header")
}
}
func TestAsDelete(t *testing.T) {
r, _ := Prepare(mocks.NewRequest(), AsDelete())
if r.Method != "DELETE" {
t.Fatal("autorest: AsDelete failed to set HTTP method header to DELETE")
}
}
func TestAsGet(t *testing.T) {
r, _ := Prepare(mocks.NewRequest(), AsGet())
if r.Method != "GET" {
t.Fatal("autorest: AsGet failed to set HTTP method header to GET")
}
}
func TestAsHead(t *testing.T) {
r, _ := Prepare(mocks.NewRequest(), AsHead())
if r.Method != "HEAD" {
t.Fatal("autorest: AsHead failed to set HTTP method header to HEAD")
}
}
func TestAsOptions(t *testing.T) {
r, _ := Prepare(mocks.NewRequest(), AsOptions())
if r.Method != "OPTIONS" {
t.Fatal("autorest: AsOptions failed to set HTTP method header to OPTIONS")
}
}
func TestAsPatch(t *testing.T) {
r, _ := Prepare(mocks.NewRequest(), AsPatch())
if r.Method != "PATCH" {
t.Fatal("autorest: AsPatch failed to set HTTP method header to PATCH")
}
}
func TestAsPost(t *testing.T) {
r, _ := Prepare(mocks.NewRequest(), AsPost())
if r.Method != "POST" {
t.Fatal("autorest: AsPost failed to set HTTP method header to POST")
}
}
func TestAsPut(t *testing.T) {
r, _ := Prepare(mocks.NewRequest(), AsPut())
if r.Method != "PUT" {
t.Fatal("autorest: AsPut failed to set HTTP method header to PUT")
}
}
func TestPrepareWithNullRequest(t *testing.T) {
_, err := Prepare(nil)
if err == nil {
t.Fatal("autorest: Prepare failed to return an error when given a null http.Request")
}
}
func TestWithFormDataSetsContentLength(t *testing.T) {
v := url.Values{}
v.Add("name", "Rob Pike")
v.Add("age", "42")
r, err := Prepare(&http.Request{},
WithFormData(v))
if err != nil {
t.Fatalf("autorest: WithFormData failed with error (%v)", err)
}
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("autorest: WithFormData failed with error (%v)", err)
}
expected := "name=Rob+Pike&age=42"
if !(string(b) == "name=Rob+Pike&age=42" || string(b) == "age=42&name=Rob+Pike") {
t.Fatalf("autorest:WithFormData failed to return correct string got (%v), expected (%v)", string(b), expected)
}
if r.ContentLength != int64(len(b)) {
t.Fatalf("autorest:WithFormData set Content-Length to %v, expected %v", r.ContentLength, len(b))
}
}
func TestWithMultiPartFormDataSetsContentLength(t *testing.T) {
v := map[string]interface{}{
"file": ioutil.NopCloser(strings.NewReader("Hello Gopher")),
"age": "42",
}
r, err := Prepare(&http.Request{},
WithMultiPartFormData(v))
if err != nil {
t.Fatalf("autorest: WithMultiPartFormData failed with error (%v)", err)
}
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("autorest: WithMultiPartFormData failed with error (%v)", err)
}
if r.ContentLength != int64(len(b)) {
t.Fatalf("autorest:WithMultiPartFormData set Content-Length to %v, expected %v", r.ContentLength, len(b))
}
}
func TestWithMultiPartFormDataWithNoFile(t *testing.T) {
v := map[string]interface{}{
"file": "no file",
"age": "42",
}
r, err := Prepare(&http.Request{},
WithMultiPartFormData(v))
if err != nil {
t.Fatalf("autorest: WithMultiPartFormData failed with error (%v)", err)
}
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("autorest: WithMultiPartFormData failed with error (%v)", err)
}
if r.ContentLength != int64(len(b)) {
t.Fatalf("autorest:WithMultiPartFormData set Content-Length to %v, expected %v", r.ContentLength, len(b))
}
}
func TestWithFile(t *testing.T) {
r, err := Prepare(&http.Request{},
WithFile(ioutil.NopCloser(strings.NewReader("Hello Gopher"))))
if err != nil {
t.Fatalf("autorest: WithFile failed with error (%v)", err)
}
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("autorest: WithFile failed with error (%v)", err)
}
if r.ContentLength != int64(len(b)) {
t.Fatalf("autorest:WithFile set Content-Length to %v, expected %v", r.ContentLength, len(b))
}
}
func TestWithBool_SetsTheBody(t *testing.T) {
r, err := Prepare(&http.Request{},
WithBool(false))
if err != nil {
t.Fatalf("autorest: WithBool failed with error (%v)", err)
}
s, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("autorest: WithBool failed with error (%v)", err)
}
if r.ContentLength != int64(len(fmt.Sprintf("%v", false))) {
t.Fatalf("autorest: WithBool set Content-Length to %v, expected %v", r.ContentLength, int64(len(fmt.Sprintf("%v", false))))
}
v, err := strconv.ParseBool(string(s))
if err != nil || v {
t.Fatalf("autorest: WithBool incorrectly encoded the boolean as %v", s)
}
}
func TestWithFloat32_SetsTheBody(t *testing.T) {
r, err := Prepare(&http.Request{},
WithFloat32(42.0))
if err != nil {
t.Fatalf("autorest: WithFloat32 failed with error (%v)", err)
}
s, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("autorest: WithFloat32 failed with error (%v)", err)
}
if r.ContentLength != int64(len(fmt.Sprintf("%v", 42.0))) {
t.Fatalf("autorest: WithFloat32 set Content-Length to %v, expected %v", r.ContentLength, int64(len(fmt.Sprintf("%v", 42.0))))
}
v, err := strconv.ParseFloat(string(s), 32)
if err != nil || float32(v) != float32(42.0) {
t.Fatalf("autorest: WithFloat32 incorrectly encoded the boolean as %v", s)
}
}
func TestWithFloat64_SetsTheBody(t *testing.T) {
r, err := Prepare(&http.Request{},
WithFloat64(42.0))
if err != nil {
t.Fatalf("autorest: WithFloat64 failed with error (%v)", err)
}
s, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("autorest: WithFloat64 failed with error (%v)", err)
}
if r.ContentLength != int64(len(fmt.Sprintf("%v", 42.0))) {
t.Fatalf("autorest: WithFloat64 set Content-Length to %v, expected %v", r.ContentLength, int64(len(fmt.Sprintf("%v", 42.0))))
}
v, err := strconv.ParseFloat(string(s), 64)
if err != nil || v != float64(42.0) {
t.Fatalf("autorest: WithFloat64 incorrectly encoded the boolean as %v", s)
}
}
func TestWithInt32_SetsTheBody(t *testing.T) {
r, err := Prepare(&http.Request{},
WithInt32(42))
if err != nil {
t.Fatalf("autorest: WithInt32 failed with error (%v)", err)
}
s, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("autorest: WithInt32 failed with error (%v)", err)
}
if r.ContentLength != int64(len(fmt.Sprintf("%v", 42))) {
t.Fatalf("autorest: WithInt32 set Content-Length to %v, expected %v", r.ContentLength, int64(len(fmt.Sprintf("%v", 42))))
}
v, err := strconv.ParseInt(string(s), 10, 32)
if err != nil || int32(v) != int32(42) {
t.Fatalf("autorest: WithInt32 incorrectly encoded the boolean as %v", s)
}
}
func TestWithInt64_SetsTheBody(t *testing.T) {
r, err := Prepare(&http.Request{},
WithInt64(42))
if err != nil {
t.Fatalf("autorest: WithInt64 failed with error (%v)", err)
}
s, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("autorest: WithInt64 failed with error (%v)", err)
}
if r.ContentLength != int64(len(fmt.Sprintf("%v", 42))) {
t.Fatalf("autorest: WithInt64 set Content-Length to %v, expected %v", r.ContentLength, int64(len(fmt.Sprintf("%v", 42))))
}
v, err := strconv.ParseInt(string(s), 10, 64)
if err != nil || v != int64(42) {
t.Fatalf("autorest: WithInt64 incorrectly encoded the boolean as %v", s)
}
}
func TestWithString_SetsTheBody(t *testing.T) {
r, err := Prepare(&http.Request{},
WithString("value"))
if err != nil {
t.Fatalf("autorest: WithString failed with error (%v)", err)
}
s, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("autorest: WithString failed with error (%v)", err)
}
if r.ContentLength != int64(len("value")) {
t.Fatalf("autorest: WithString set Content-Length to %v, expected %v", r.ContentLength, int64(len("value")))
}
if string(s) != "value" {
t.Fatalf("autorest: WithString incorrectly encoded the string as %v", s)
}
}
func TestWithJSONSetsContentLength(t *testing.T) {
r, err := Prepare(&http.Request{},
WithJSON(&mocks.T{Name: "Rob Pike", Age: 42}))
if err != nil {
t.Fatalf("autorest: WithJSON failed with error (%v)", err)
}
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("autorest: WithJSON failed with error (%v)", err)
}
if r.ContentLength != int64(len(b)) {
t.Fatalf("autorest:WithJSON set Content-Length to %v, expected %v", r.ContentLength, len(b))
}
}
func TestWithHeaderAllocatesHeaders(t *testing.T) {
r, err := Prepare(mocks.NewRequest(), WithHeader("x-foo", "bar"))
if err != nil {
t.Fatalf("autorest: WithHeader failed (%v)", err)
}
if r.Header.Get("x-foo") != "bar" {
t.Fatalf("autorest: WithHeader failed to add header (%s=%s)", "x-foo", r.Header.Get("x-foo"))
}
}
func TestWithPathCatchesNilURL(t *testing.T) {
_, err := Prepare(&http.Request{}, WithPath("a"))
if err == nil {
t.Fatalf("autorest: WithPath failed to catch a nil URL")
}
}
func TestWithEscapedPathParametersCatchesNilURL(t *testing.T) {
_, err := Prepare(&http.Request{}, WithEscapedPathParameters("", map[string]interface{}{"foo": "bar"}))
if err == nil {
t.Fatalf("autorest: WithEscapedPathParameters failed to catch a nil URL")
}
}
func TestWithPathParametersCatchesNilURL(t *testing.T) {
_, err := Prepare(&http.Request{}, WithPathParameters("", map[string]interface{}{"foo": "bar"}))
if err == nil {
t.Fatalf("autorest: WithPathParameters failed to catch a nil URL")
}
}
func TestWithQueryParametersCatchesNilURL(t *testing.T) {
_, err := Prepare(&http.Request{}, WithQueryParameters(map[string]interface{}{"foo": "bar"}))
if err == nil {
t.Fatalf("autorest: WithQueryParameters failed to catch a nil URL")
}
}
func TestModifyingExistingRequest(t *testing.T) {
r, err := Prepare(mocks.NewRequestForURL("https://bing.com"), WithPath("search"), WithQueryParameters(map[string]interface{}{"q": "golang"}))
if err != nil {
t.Fatalf("autorest: Preparing an existing request returned an error (%v)", err)
}
if r.URL.String() != "https:/search?q=golang" && r.URL.Host != "bing.com" {
t.Fatalf("autorest: Preparing an existing request failed (%s)", r.URL)
}
}

View File

@@ -1,250 +0,0 @@
package autorest
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"bytes"
"encoding/json"
"encoding/xml"
"fmt"
"io"
"io/ioutil"
"net/http"
"strings"
)
// Responder is the interface that wraps the Respond method.
//
// Respond accepts and reacts to an http.Response. Implementations must ensure to not share or hold
// state since Responders may be shared and re-used.
type Responder interface {
Respond(*http.Response) error
}
// ResponderFunc is a method that implements the Responder interface.
type ResponderFunc func(*http.Response) error
// Respond implements the Responder interface on ResponderFunc.
func (rf ResponderFunc) Respond(r *http.Response) error {
return rf(r)
}
// RespondDecorator takes and possibly decorates, by wrapping, a Responder. Decorators may react to
// the http.Response and pass it along or, first, pass the http.Response along then react.
type RespondDecorator func(Responder) Responder
// CreateResponder creates, decorates, and returns a Responder. Without decorators, the returned
// Responder returns the passed http.Response unmodified. Responders may or may not be safe to share
// and re-used: It depends on the applied decorators. For example, a standard decorator that closes
// the response body is fine to share whereas a decorator that reads the body into a passed struct
// is not.
//
// To prevent memory leaks, ensure that at least one Responder closes the response body.
func CreateResponder(decorators ...RespondDecorator) Responder {
return DecorateResponder(
Responder(ResponderFunc(func(r *http.Response) error { return nil })),
decorators...)
}
// DecorateResponder accepts a Responder and a, possibly empty, set of RespondDecorators, which it
// applies to the Responder. Decorators are applied in the order received, but their affect upon the
// request depends on whether they are a pre-decorator (react to the http.Response and then pass it
// along) or a post-decorator (pass the http.Response along and then react).
func DecorateResponder(r Responder, decorators ...RespondDecorator) Responder {
for _, decorate := range decorators {
r = decorate(r)
}
return r
}
// Respond accepts an http.Response and a, possibly empty, set of RespondDecorators.
// It creates a Responder from the decorators it then applies to the passed http.Response.
func Respond(r *http.Response, decorators ...RespondDecorator) error {
if r == nil {
return nil
}
return CreateResponder(decorators...).Respond(r)
}
// ByIgnoring returns a RespondDecorator that ignores the passed http.Response passing it unexamined
// to the next RespondDecorator.
func ByIgnoring() RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
return r.Respond(resp)
})
}
}
// ByCopying copies the contents of the http.Response Body into the passed bytes.Buffer as
// the Body is read.
func ByCopying(b *bytes.Buffer) RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
err := r.Respond(resp)
if err == nil && resp != nil && resp.Body != nil {
resp.Body = TeeReadCloser(resp.Body, b)
}
return err
})
}
}
// ByDiscardingBody returns a RespondDecorator that first invokes the passed Responder after which
// it copies the remaining bytes (if any) in the response body to ioutil.Discard. Since the passed
// Responder is invoked prior to discarding the response body, the decorator may occur anywhere
// within the set.
func ByDiscardingBody() RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
err := r.Respond(resp)
if err == nil && resp != nil && resp.Body != nil {
if _, err := io.Copy(ioutil.Discard, resp.Body); err != nil {
return fmt.Errorf("Error discarding the response body: %v", err)
}
}
return err
})
}
}
// ByClosing returns a RespondDecorator that first invokes the passed Responder after which it
// closes the response body. Since the passed Responder is invoked prior to closing the response
// body, the decorator may occur anywhere within the set.
func ByClosing() RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
err := r.Respond(resp)
if resp != nil && resp.Body != nil {
if err := resp.Body.Close(); err != nil {
return fmt.Errorf("Error closing the response body: %v", err)
}
}
return err
})
}
}
// ByClosingIfError returns a RespondDecorator that first invokes the passed Responder after which
// it closes the response if the passed Responder returns an error and the response body exists.
func ByClosingIfError() RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
err := r.Respond(resp)
if err != nil && resp != nil && resp.Body != nil {
if err := resp.Body.Close(); err != nil {
return fmt.Errorf("Error closing the response body: %v", err)
}
}
return err
})
}
}
// ByUnmarshallingJSON returns a RespondDecorator that decodes a JSON document returned in the
// response Body into the value pointed to by v.
func ByUnmarshallingJSON(v interface{}) RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
err := r.Respond(resp)
if err == nil {
b, errInner := ioutil.ReadAll(resp.Body)
// Some responses might include a BOM, remove for successful unmarshalling
b = bytes.TrimPrefix(b, []byte("\xef\xbb\xbf"))
if errInner != nil {
err = fmt.Errorf("Error occurred reading http.Response#Body - Error = '%v'", errInner)
} else if len(strings.Trim(string(b), " ")) > 0 {
errInner = json.Unmarshal(b, v)
if errInner != nil {
err = fmt.Errorf("Error occurred unmarshalling JSON - Error = '%v' JSON = '%s'", errInner, string(b))
}
}
}
return err
})
}
}
// ByUnmarshallingXML returns a RespondDecorator that decodes a XML document returned in the
// response Body into the value pointed to by v.
func ByUnmarshallingXML(v interface{}) RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
err := r.Respond(resp)
if err == nil {
b, errInner := ioutil.ReadAll(resp.Body)
if errInner != nil {
err = fmt.Errorf("Error occurred reading http.Response#Body - Error = '%v'", errInner)
} else {
errInner = xml.Unmarshal(b, v)
if errInner != nil {
err = fmt.Errorf("Error occurred unmarshalling Xml - Error = '%v' Xml = '%s'", errInner, string(b))
}
}
}
return err
})
}
}
// WithErrorUnlessStatusCode returns a RespondDecorator that emits an error unless the response
// StatusCode is among the set passed. On error, response body is fully read into a buffer and
// presented in the returned error, as well as in the response body.
func WithErrorUnlessStatusCode(codes ...int) RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
err := r.Respond(resp)
if err == nil && !ResponseHasStatusCode(resp, codes...) {
derr := NewErrorWithResponse("autorest", "WithErrorUnlessStatusCode", resp, "%v %v failed with %s",
resp.Request.Method,
resp.Request.URL,
resp.Status)
if resp.Body != nil {
defer resp.Body.Close()
b, _ := ioutil.ReadAll(resp.Body)
derr.ServiceError = b
resp.Body = ioutil.NopCloser(bytes.NewReader(b))
}
err = derr
}
return err
})
}
}
// WithErrorUnlessOK returns a RespondDecorator that emits an error if the response StatusCode is
// anything other than HTTP 200.
func WithErrorUnlessOK() RespondDecorator {
return WithErrorUnlessStatusCode(http.StatusOK)
}
// ExtractHeader extracts all values of the specified header from the http.Response. It returns an
// empty string slice if the passed http.Response is nil or the header does not exist.
func ExtractHeader(header string, resp *http.Response) []string {
if resp != nil && resp.Header != nil {
return resp.Header[http.CanonicalHeaderKey(header)]
}
return nil
}
// ExtractHeaderValue extracts the first value of the specified header from the http.Response. It
// returns an empty string if the passed http.Response is nil or the header does not exist.
func ExtractHeaderValue(header string, resp *http.Response) string {
h := ExtractHeader(header, resp)
if len(h) > 0 {
return h[0]
}
return ""
}

View File

@@ -1,665 +0,0 @@
package autorest
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"reflect"
"strings"
"testing"
"github.com/Azure/go-autorest/autorest/mocks"
)
func ExampleWithErrorUnlessOK() {
r := mocks.NewResponse()
r.Request = mocks.NewRequest()
// Respond and leave the response body open (for a subsequent responder to close)
err := Respond(r,
WithErrorUnlessOK(),
ByDiscardingBody(),
ByClosingIfError())
if err == nil {
fmt.Printf("%s of %s returned HTTP 200", r.Request.Method, r.Request.URL)
// Complete handling the response and close the body
Respond(r,
ByDiscardingBody(),
ByClosing())
}
// Output: GET of https://microsoft.com/a/b/c/ returned HTTP 200
}
func ExampleByUnmarshallingJSON() {
c := `
{
"name" : "Rob Pike",
"age" : 42
}
`
type V struct {
Name string `json:"name"`
Age int `json:"age"`
}
v := &V{}
Respond(mocks.NewResponseWithContent(c),
ByUnmarshallingJSON(v),
ByClosing())
fmt.Printf("%s is %d years old\n", v.Name, v.Age)
// Output: Rob Pike is 42 years old
}
func ExampleByUnmarshallingXML() {
c := `<?xml version="1.0" encoding="UTF-8"?>
<Person>
<Name>Rob Pike</Name>
<Age>42</Age>
</Person>`
type V struct {
Name string `xml:"Name"`
Age int `xml:"Age"`
}
v := &V{}
Respond(mocks.NewResponseWithContent(c),
ByUnmarshallingXML(v),
ByClosing())
fmt.Printf("%s is %d years old\n", v.Name, v.Age)
// Output: Rob Pike is 42 years old
}
func TestCreateResponderDoesNotModify(t *testing.T) {
r1 := mocks.NewResponse()
r2 := mocks.NewResponse()
p := CreateResponder()
err := p.Respond(r1)
if err != nil {
t.Fatalf("autorest: CreateResponder failed (%v)", err)
}
if !reflect.DeepEqual(r1, r2) {
t.Fatalf("autorest: CreateResponder without decorators modified the response")
}
}
func TestCreateResponderRunsDecoratorsInOrder(t *testing.T) {
s := ""
d := func(n int) RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
err := r.Respond(resp)
if err == nil {
s += fmt.Sprintf("%d", n)
}
return err
})
}
}
p := CreateResponder(d(1), d(2), d(3))
err := p.Respond(&http.Response{})
if err != nil {
t.Fatalf("autorest: Respond failed (%v)", err)
}
if s != "123" {
t.Fatalf("autorest: CreateResponder invoked decorators in an incorrect order; expected '123', received '%s'", s)
}
}
func TestByIgnoring(t *testing.T) {
r := mocks.NewResponse()
Respond(r,
(func() RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(r2 *http.Response) error {
r1 := mocks.NewResponse()
if !reflect.DeepEqual(r1, r2) {
t.Fatalf("autorest: ByIgnoring modified the HTTP Response -- received %v, expected %v", r2, r1)
}
return nil
})
}
})(),
ByIgnoring(),
ByClosing())
}
func TestByCopying_Copies(t *testing.T) {
r := mocks.NewResponseWithContent(jsonT)
b := &bytes.Buffer{}
err := Respond(r,
ByCopying(b),
ByUnmarshallingJSON(&mocks.T{}),
ByClosing())
if err != nil {
t.Fatalf("autorest: ByCopying returned an unexpected error -- %v", err)
}
if b.String() != jsonT {
t.Fatalf("autorest: ByCopying failed to copy the bytes read")
}
}
func TestByCopying_ReturnsNestedErrors(t *testing.T) {
r := mocks.NewResponseWithContent(jsonT)
r.Body.Close()
err := Respond(r,
ByCopying(&bytes.Buffer{}),
ByUnmarshallingJSON(&mocks.T{}),
ByClosing())
if err == nil {
t.Fatalf("autorest: ByCopying failed to return the expected error")
}
}
func TestByCopying_AcceptsNilReponse(t *testing.T) {
r := mocks.NewResponse()
Respond(r,
(func() RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
resp.Body.Close()
r.Respond(nil)
return nil
})
}
})(),
ByCopying(&bytes.Buffer{}))
}
func TestByCopying_AcceptsNilBody(t *testing.T) {
r := mocks.NewResponse()
Respond(r,
(func() RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
resp.Body.Close()
resp.Body = nil
r.Respond(resp)
return nil
})
}
})(),
ByCopying(&bytes.Buffer{}))
}
func TestByClosing(t *testing.T) {
r := mocks.NewResponse()
err := Respond(r, ByClosing())
if err != nil {
t.Fatalf("autorest: ByClosing failed (%v)", err)
}
if r.Body.(*mocks.Body).IsOpen() {
t.Fatalf("autorest: ByClosing did not close the response body")
}
}
func TestByClosingAcceptsNilResponse(t *testing.T) {
r := mocks.NewResponse()
Respond(r,
(func() RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
resp.Body.Close()
r.Respond(nil)
return nil
})
}
})(),
ByClosing())
}
func TestByClosingAcceptsNilBody(t *testing.T) {
r := mocks.NewResponse()
Respond(r,
(func() RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
resp.Body.Close()
resp.Body = nil
r.Respond(resp)
return nil
})
}
})(),
ByClosing())
}
func TestByClosingClosesEvenAfterErrors(t *testing.T) {
var e error
r := mocks.NewResponse()
Respond(r,
withErrorRespondDecorator(&e),
ByClosing())
if r.Body.(*mocks.Body).IsOpen() {
t.Fatalf("autorest: ByClosing did not close the response body after an error occurred")
}
}
func TestByClosingClosesReturnsNestedErrors(t *testing.T) {
var e error
r := mocks.NewResponse()
err := Respond(r,
withErrorRespondDecorator(&e),
ByClosing())
if err == nil || !reflect.DeepEqual(e, err) {
t.Fatalf("autorest: ByClosing failed to return a nested error")
}
}
func TestByClosingIfErrorAcceptsNilResponse(t *testing.T) {
var e error
r := mocks.NewResponse()
Respond(r,
withErrorRespondDecorator(&e),
(func() RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
resp.Body.Close()
r.Respond(nil)
return nil
})
}
})(),
ByClosingIfError())
}
func TestByClosingIfErrorAcceptsNilBody(t *testing.T) {
var e error
r := mocks.NewResponse()
Respond(r,
withErrorRespondDecorator(&e),
(func() RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
resp.Body.Close()
resp.Body = nil
r.Respond(resp)
return nil
})
}
})(),
ByClosingIfError())
}
func TestByClosingIfErrorClosesIfAnErrorOccurs(t *testing.T) {
var e error
r := mocks.NewResponse()
Respond(r,
withErrorRespondDecorator(&e),
ByClosingIfError())
if r.Body.(*mocks.Body).IsOpen() {
t.Fatalf("autorest: ByClosingIfError did not close the response body after an error occurred")
}
}
func TestByClosingIfErrorDoesNotClosesIfNoErrorOccurs(t *testing.T) {
r := mocks.NewResponse()
Respond(r,
ByClosingIfError())
if !r.Body.(*mocks.Body).IsOpen() {
t.Fatalf("autorest: ByClosingIfError closed the response body even though no error occurred")
}
}
func TestByDiscardingBody(t *testing.T) {
r := mocks.NewResponse()
err := Respond(r,
ByDiscardingBody())
if err != nil {
t.Fatalf("autorest: ByDiscardingBody failed (%v)", err)
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("autorest: Reading result of ByDiscardingBody failed (%v)", err)
}
if len(buf) != 0 {
t.Logf("autorest: Body was not empty after calling ByDiscardingBody.")
t.Fail()
}
}
func TestByDiscardingBodyAcceptsNilResponse(t *testing.T) {
var e error
r := mocks.NewResponse()
Respond(r,
withErrorRespondDecorator(&e),
(func() RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
resp.Body.Close()
r.Respond(nil)
return nil
})
}
})(),
ByDiscardingBody())
}
func TestByDiscardingBodyAcceptsNilBody(t *testing.T) {
var e error
r := mocks.NewResponse()
Respond(r,
withErrorRespondDecorator(&e),
(func() RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
resp.Body.Close()
resp.Body = nil
r.Respond(resp)
return nil
})
}
})(),
ByDiscardingBody())
}
func TestByUnmarshallingJSON(t *testing.T) {
v := &mocks.T{}
r := mocks.NewResponseWithContent(jsonT)
err := Respond(r,
ByUnmarshallingJSON(v),
ByClosing())
if err != nil {
t.Fatalf("autorest: ByUnmarshallingJSON failed (%v)", err)
}
if v.Name != "Rob Pike" || v.Age != 42 {
t.Fatalf("autorest: ByUnmarshallingJSON failed to properly unmarshal")
}
}
func TestByUnmarshallingJSON_HandlesReadErrors(t *testing.T) {
v := &mocks.T{}
r := mocks.NewResponseWithContent(jsonT)
r.Body.(*mocks.Body).Close()
err := Respond(r,
ByUnmarshallingJSON(v),
ByClosing())
if err == nil {
t.Fatalf("autorest: ByUnmarshallingJSON failed to receive / respond to read error")
}
}
func TestByUnmarshallingJSONIncludesJSONInErrors(t *testing.T) {
v := &mocks.T{}
j := jsonT[0 : len(jsonT)-2]
r := mocks.NewResponseWithContent(j)
err := Respond(r,
ByUnmarshallingJSON(v),
ByClosing())
if err == nil || !strings.Contains(err.Error(), j) {
t.Fatalf("autorest: ByUnmarshallingJSON failed to return JSON in error (%v)", err)
}
}
func TestByUnmarshallingJSONEmptyInput(t *testing.T) {
v := &mocks.T{}
r := mocks.NewResponseWithContent(``)
err := Respond(r,
ByUnmarshallingJSON(v),
ByClosing())
if err != nil {
t.Fatalf("autorest: ByUnmarshallingJSON failed to return nil in case of empty JSON (%v)", err)
}
}
func TestByUnmarshallingXML(t *testing.T) {
v := &mocks.T{}
r := mocks.NewResponseWithContent(xmlT)
err := Respond(r,
ByUnmarshallingXML(v),
ByClosing())
if err != nil {
t.Fatalf("autorest: ByUnmarshallingXML failed (%v)", err)
}
if v.Name != "Rob Pike" || v.Age != 42 {
t.Fatalf("autorest: ByUnmarshallingXML failed to properly unmarshal")
}
}
func TestByUnmarshallingXML_HandlesReadErrors(t *testing.T) {
v := &mocks.T{}
r := mocks.NewResponseWithContent(xmlT)
r.Body.(*mocks.Body).Close()
err := Respond(r,
ByUnmarshallingXML(v),
ByClosing())
if err == nil {
t.Fatalf("autorest: ByUnmarshallingXML failed to receive / respond to read error")
}
}
func TestByUnmarshallingXMLIncludesXMLInErrors(t *testing.T) {
v := &mocks.T{}
x := xmlT[0 : len(xmlT)-2]
r := mocks.NewResponseWithContent(x)
err := Respond(r,
ByUnmarshallingXML(v),
ByClosing())
if err == nil || !strings.Contains(err.Error(), x) {
t.Fatalf("autorest: ByUnmarshallingXML failed to return XML in error (%v)", err)
}
}
func TestRespondAcceptsNullResponse(t *testing.T) {
err := Respond(nil)
if err != nil {
t.Fatalf("autorest: Respond returned an unexpected error when given a null Response (%v)", err)
}
}
func TestWithErrorUnlessStatusCodeOKResponse(t *testing.T) {
v := &mocks.T{}
r := mocks.NewResponseWithContent(jsonT)
err := Respond(r,
WithErrorUnlessStatusCode(http.StatusOK),
ByUnmarshallingJSON(v),
ByClosing())
if err != nil {
t.Fatalf("autorest: WithErrorUnlessStatusCode(http.StatusOK) failed on okay response. (%v)", err)
}
if v.Name != "Rob Pike" || v.Age != 42 {
t.Fatalf("autorest: WithErrorUnlessStatusCode(http.StatusOK) corrupted the response body of okay response.")
}
}
func TesWithErrorUnlessStatusCodeErrorResponse(t *testing.T) {
v := &mocks.T{}
e := &mocks.T{}
r := mocks.NewResponseWithContent(jsonT)
r.Status = "400 BadRequest"
r.StatusCode = http.StatusBadRequest
err := Respond(r,
WithErrorUnlessStatusCode(http.StatusOK),
ByUnmarshallingJSON(v),
ByClosing())
if err == nil {
t.Fatal("autorest: WithErrorUnlessStatusCode(http.StatusOK) did not return error, on a response to a bad request.")
}
var errorRespBody []byte
if derr, ok := err.(DetailedError); !ok {
t.Fatalf("autorest: WithErrorUnlessStatusCode(http.StatusOK) got wrong error type : %T, expected: DetailedError, on a response to a bad request.", err)
} else {
errorRespBody = derr.ServiceError
}
if errorRespBody == nil {
t.Fatalf("autorest: WithErrorUnlessStatusCode(http.StatusOK) ServiceError not returned in DetailedError on a response to a bad request.")
}
err = json.Unmarshal(errorRespBody, e)
if err != nil {
t.Fatalf("autorest: WithErrorUnlessStatusCode(http.StatusOK) cannot parse error returned in ServiceError into json. %v", err)
}
expected := &mocks.T{Name: "Rob Pike", Age: 42}
if e != expected {
t.Fatalf("autorest: WithErrorUnlessStatusCode(http.StatusOK wrong value from parsed ServiceError: got=%#v expected=%#v", e, expected)
}
}
func TestWithErrorUnlessStatusCode(t *testing.T) {
r := mocks.NewResponse()
r.Request = mocks.NewRequest()
r.Status = "400 BadRequest"
r.StatusCode = http.StatusBadRequest
err := Respond(r,
WithErrorUnlessStatusCode(http.StatusBadRequest, http.StatusUnauthorized, http.StatusInternalServerError),
ByClosingIfError())
if err != nil {
t.Fatalf("autorest: WithErrorUnlessStatusCode returned an error (%v) for an acceptable status code (%s)", err, r.Status)
}
}
func TestWithErrorUnlessStatusCodeEmitsErrorForUnacceptableStatusCode(t *testing.T) {
r := mocks.NewResponse()
r.Request = mocks.NewRequest()
r.Status = "400 BadRequest"
r.StatusCode = http.StatusBadRequest
err := Respond(r,
WithErrorUnlessStatusCode(http.StatusOK, http.StatusUnauthorized, http.StatusInternalServerError),
ByClosingIfError())
if err == nil {
t.Fatalf("autorest: WithErrorUnlessStatusCode failed to return an error for an unacceptable status code (%s)", r.Status)
}
}
func TestWithErrorUnlessOK(t *testing.T) {
r := mocks.NewResponse()
r.Request = mocks.NewRequest()
err := Respond(r,
WithErrorUnlessOK(),
ByClosingIfError())
if err != nil {
t.Fatalf("autorest: WithErrorUnlessOK returned an error for OK status code (%v)", err)
}
}
func TestWithErrorUnlessOKEmitsErrorIfNotOK(t *testing.T) {
r := mocks.NewResponse()
r.Request = mocks.NewRequest()
r.Status = "400 BadRequest"
r.StatusCode = http.StatusBadRequest
err := Respond(r,
WithErrorUnlessOK(),
ByClosingIfError())
if err == nil {
t.Fatalf("autorest: WithErrorUnlessOK failed to return an error for a non-OK status code (%v)", err)
}
}
func TestExtractHeader(t *testing.T) {
r := mocks.NewResponse()
v := []string{"v1", "v2", "v3"}
mocks.SetResponseHeaderValues(r, mocks.TestHeader, v)
if !reflect.DeepEqual(ExtractHeader(mocks.TestHeader, r), v) {
t.Fatalf("autorest: ExtractHeader failed to retrieve the expected header -- expected [%s]%v, received [%s]%v",
mocks.TestHeader, v, mocks.TestHeader, ExtractHeader(mocks.TestHeader, r))
}
}
func TestExtractHeaderHandlesMissingHeader(t *testing.T) {
var v []string
r := mocks.NewResponse()
if !reflect.DeepEqual(ExtractHeader(mocks.TestHeader, r), v) {
t.Fatalf("autorest: ExtractHeader failed to handle a missing header -- expected %v, received %v",
v, ExtractHeader(mocks.TestHeader, r))
}
}
func TestExtractHeaderValue(t *testing.T) {
r := mocks.NewResponse()
v := "v1"
mocks.SetResponseHeader(r, mocks.TestHeader, v)
if ExtractHeaderValue(mocks.TestHeader, r) != v {
t.Fatalf("autorest: ExtractHeader failed to retrieve the expected header -- expected [%s]%v, received [%s]%v",
mocks.TestHeader, v, mocks.TestHeader, ExtractHeaderValue(mocks.TestHeader, r))
}
}
func TestExtractHeaderValueHandlesMissingHeader(t *testing.T) {
r := mocks.NewResponse()
v := ""
if ExtractHeaderValue(mocks.TestHeader, r) != v {
t.Fatalf("autorest: ExtractHeader failed to retrieve the expected header -- expected [%s]%v, received [%s]%v",
mocks.TestHeader, v, mocks.TestHeader, ExtractHeaderValue(mocks.TestHeader, r))
}
}
func TestExtractHeaderValueRetrievesFirstValue(t *testing.T) {
r := mocks.NewResponse()
v := []string{"v1", "v2", "v3"}
mocks.SetResponseHeaderValues(r, mocks.TestHeader, v)
if ExtractHeaderValue(mocks.TestHeader, r) != v[0] {
t.Fatalf("autorest: ExtractHeader failed to retrieve the expected header -- expected [%s]%v, received [%s]%v",
mocks.TestHeader, v[0], mocks.TestHeader, ExtractHeaderValue(mocks.TestHeader, r))
}
}

View File

@@ -1,52 +0,0 @@
package autorest
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"bytes"
"io"
"io/ioutil"
"net/http"
)
// NewRetriableRequest returns a wrapper around an HTTP request that support retry logic.
func NewRetriableRequest(req *http.Request) *RetriableRequest {
return &RetriableRequest{req: req}
}
// Request returns the wrapped HTTP request.
func (rr *RetriableRequest) Request() *http.Request {
return rr.req
}
func (rr *RetriableRequest) prepareFromByteReader() (err error) {
// fall back to making a copy (only do this once)
b := []byte{}
if rr.req.ContentLength > 0 {
b = make([]byte, rr.req.ContentLength)
_, err = io.ReadFull(rr.req.Body, b)
if err != nil {
return err
}
} else {
b, err = ioutil.ReadAll(rr.req.Body)
if err != nil {
return err
}
}
rr.br = bytes.NewReader(b)
rr.req.Body = ioutil.NopCloser(rr.br)
return err
}

View File

@@ -1,54 +0,0 @@
// +build !go1.8
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package autorest
import (
"bytes"
"io/ioutil"
"net/http"
)
// RetriableRequest provides facilities for retrying an HTTP request.
type RetriableRequest struct {
req *http.Request
br *bytes.Reader
}
// Prepare signals that the request is about to be sent.
func (rr *RetriableRequest) Prepare() (err error) {
// preserve the request body; this is to support retry logic as
// the underlying transport will always close the reqeust body
if rr.req.Body != nil {
if rr.br != nil {
_, err = rr.br.Seek(0, 0 /*io.SeekStart*/)
rr.req.Body = ioutil.NopCloser(rr.br)
}
if err != nil {
return err
}
if rr.br == nil {
// fall back to making a copy (only do this once)
err = rr.prepareFromByteReader()
}
}
return err
}
func removeRequestBody(req *http.Request) {
req.Body = nil
req.ContentLength = 0
}

View File

@@ -1,66 +0,0 @@
// +build go1.8
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package autorest
import (
"bytes"
"io"
"io/ioutil"
"net/http"
)
// RetriableRequest provides facilities for retrying an HTTP request.
type RetriableRequest struct {
req *http.Request
rc io.ReadCloser
br *bytes.Reader
}
// Prepare signals that the request is about to be sent.
func (rr *RetriableRequest) Prepare() (err error) {
// preserve the request body; this is to support retry logic as
// the underlying transport will always close the reqeust body
if rr.req.Body != nil {
if rr.rc != nil {
rr.req.Body = rr.rc
} else if rr.br != nil {
_, err = rr.br.Seek(0, io.SeekStart)
rr.req.Body = ioutil.NopCloser(rr.br)
}
if err != nil {
return err
}
if rr.req.GetBody != nil {
// this will allow us to preserve the body without having to
// make a copy. note we need to do this on each iteration
rr.rc, err = rr.req.GetBody()
if err != nil {
return err
}
} else if rr.br == nil {
// fall back to making a copy (only do this once)
err = rr.prepareFromByteReader()
}
}
return err
}
func removeRequestBody(req *http.Request) {
req.Body = nil
req.GetBody = nil
req.ContentLength = 0
}

View File

@@ -1,311 +0,0 @@
package autorest
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"fmt"
"log"
"math"
"net/http"
"strconv"
"time"
)
// Sender is the interface that wraps the Do method to send HTTP requests.
//
// The standard http.Client conforms to this interface.
type Sender interface {
Do(*http.Request) (*http.Response, error)
}
// SenderFunc is a method that implements the Sender interface.
type SenderFunc func(*http.Request) (*http.Response, error)
// Do implements the Sender interface on SenderFunc.
func (sf SenderFunc) Do(r *http.Request) (*http.Response, error) {
return sf(r)
}
// SendDecorator takes and possibily decorates, by wrapping, a Sender. Decorators may affect the
// http.Request and pass it along or, first, pass the http.Request along then react to the
// http.Response result.
type SendDecorator func(Sender) Sender
// CreateSender creates, decorates, and returns, as a Sender, the default http.Client.
func CreateSender(decorators ...SendDecorator) Sender {
return DecorateSender(&http.Client{}, decorators...)
}
// DecorateSender accepts a Sender and a, possibly empty, set of SendDecorators, which is applies to
// the Sender. Decorators are applied in the order received, but their affect upon the request
// depends on whether they are a pre-decorator (change the http.Request and then pass it along) or a
// post-decorator (pass the http.Request along and react to the results in http.Response).
func DecorateSender(s Sender, decorators ...SendDecorator) Sender {
for _, decorate := range decorators {
s = decorate(s)
}
return s
}
// Send sends, by means of the default http.Client, the passed http.Request, returning the
// http.Response and possible error. It also accepts a, possibly empty, set of SendDecorators which
// it will apply the http.Client before invoking the Do method.
//
// Send is a convenience method and not recommended for production. Advanced users should use
// SendWithSender, passing and sharing their own Sender (e.g., instance of http.Client).
//
// Send will not poll or retry requests.
func Send(r *http.Request, decorators ...SendDecorator) (*http.Response, error) {
return SendWithSender(&http.Client{}, r, decorators...)
}
// SendWithSender sends the passed http.Request, through the provided Sender, returning the
// http.Response and possible error. It also accepts a, possibly empty, set of SendDecorators which
// it will apply the http.Client before invoking the Do method.
//
// SendWithSender will not poll or retry requests.
func SendWithSender(s Sender, r *http.Request, decorators ...SendDecorator) (*http.Response, error) {
return DecorateSender(s, decorators...).Do(r)
}
// AfterDelay returns a SendDecorator that delays for the passed time.Duration before
// invoking the Sender. The delay may be terminated by closing the optional channel on the
// http.Request. If canceled, no further Senders are invoked.
func AfterDelay(d time.Duration) SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
if !DelayForBackoff(d, 0, r.Cancel) {
return nil, fmt.Errorf("autorest: AfterDelay canceled before full delay")
}
return s.Do(r)
})
}
}
// AsIs returns a SendDecorator that invokes the passed Sender without modifying the http.Request.
func AsIs() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
return s.Do(r)
})
}
}
// DoCloseIfError returns a SendDecorator that first invokes the passed Sender after which
// it closes the response if the passed Sender returns an error and the response body exists.
func DoCloseIfError() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
resp, err := s.Do(r)
if err != nil {
Respond(resp, ByDiscardingBody(), ByClosing())
}
return resp, err
})
}
}
// DoErrorIfStatusCode returns a SendDecorator that emits an error if the response StatusCode is
// among the set passed. Since these are artificial errors, the response body may still require
// closing.
func DoErrorIfStatusCode(codes ...int) SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
resp, err := s.Do(r)
if err == nil && ResponseHasStatusCode(resp, codes...) {
err = NewErrorWithResponse("autorest", "DoErrorIfStatusCode", resp, "%v %v failed with %s",
resp.Request.Method,
resp.Request.URL,
resp.Status)
}
return resp, err
})
}
}
// DoErrorUnlessStatusCode returns a SendDecorator that emits an error unless the response
// StatusCode is among the set passed. Since these are artificial errors, the response body
// may still require closing.
func DoErrorUnlessStatusCode(codes ...int) SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
resp, err := s.Do(r)
if err == nil && !ResponseHasStatusCode(resp, codes...) {
err = NewErrorWithResponse("autorest", "DoErrorUnlessStatusCode", resp, "%v %v failed with %s",
resp.Request.Method,
resp.Request.URL,
resp.Status)
}
return resp, err
})
}
}
// DoPollForStatusCodes returns a SendDecorator that polls if the http.Response contains one of the
// passed status codes. It expects the http.Response to contain a Location header providing the
// URL at which to poll (using GET) and will poll until the time passed is equal to or greater than
// the supplied duration. It will delay between requests for the duration specified in the
// RetryAfter header or, if the header is absent, the passed delay. Polling may be canceled by
// closing the optional channel on the http.Request.
func DoPollForStatusCodes(duration time.Duration, delay time.Duration, codes ...int) SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (resp *http.Response, err error) {
resp, err = s.Do(r)
if err == nil && ResponseHasStatusCode(resp, codes...) {
r, err = NewPollingRequest(resp, r.Cancel)
for err == nil && ResponseHasStatusCode(resp, codes...) {
Respond(resp,
ByDiscardingBody(),
ByClosing())
resp, err = SendWithSender(s, r,
AfterDelay(GetRetryAfter(resp, delay)))
}
}
return resp, err
})
}
}
// DoRetryForAttempts returns a SendDecorator that retries a failed request for up to the specified
// number of attempts, exponentially backing off between requests using the supplied backoff
// time.Duration (which may be zero). Retrying may be canceled by closing the optional channel on
// the http.Request.
func DoRetryForAttempts(attempts int, backoff time.Duration) SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (resp *http.Response, err error) {
rr := NewRetriableRequest(r)
for attempt := 0; attempt < attempts; attempt++ {
err = rr.Prepare()
if err != nil {
return resp, err
}
resp, err = s.Do(rr.Request())
if err == nil {
return resp, err
}
DelayForBackoff(backoff, attempt, r.Cancel)
}
return resp, err
})
}
}
// DoRetryForStatusCodes returns a SendDecorator that retries for specified statusCodes for up to the specified
// number of attempts, exponentially backing off between requests using the supplied backoff
// time.Duration (which may be zero). Retrying may be canceled by closing the optional channel on
// the http.Request.
func DoRetryForStatusCodes(attempts int, backoff time.Duration, codes ...int) SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (resp *http.Response, err error) {
rr := NewRetriableRequest(r)
// Increment to add the first call (attempts denotes number of retries)
attempts++
for attempt := 0; attempt < attempts; attempt++ {
err = rr.Prepare()
if err != nil {
return resp, err
}
resp, err = s.Do(rr.Request())
// we want to retry if err is not nil (e.g. transient network failure)
if err == nil && !ResponseHasStatusCode(resp, codes...) {
return resp, err
}
delayed := DelayWithRetryAfter(resp, r.Cancel)
if !delayed {
DelayForBackoff(backoff, attempt, r.Cancel)
}
}
return resp, err
})
}
}
// DelayWithRetryAfter invokes time.After for the duration specified in the "Retry-After" header in
// responses with status code 429
func DelayWithRetryAfter(resp *http.Response, cancel <-chan struct{}) bool {
if resp == nil {
return false
}
retryAfter, _ := strconv.Atoi(resp.Header.Get("Retry-After"))
if resp.StatusCode == http.StatusTooManyRequests && retryAfter > 0 {
select {
case <-time.After(time.Duration(retryAfter) * time.Second):
return true
case <-cancel:
return false
}
}
return false
}
// DoRetryForDuration returns a SendDecorator that retries the request until the total time is equal
// to or greater than the specified duration, exponentially backing off between requests using the
// supplied backoff time.Duration (which may be zero). Retrying may be canceled by closing the
// optional channel on the http.Request.
func DoRetryForDuration(d time.Duration, backoff time.Duration) SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (resp *http.Response, err error) {
rr := NewRetriableRequest(r)
end := time.Now().Add(d)
for attempt := 0; time.Now().Before(end); attempt++ {
err = rr.Prepare()
if err != nil {
return resp, err
}
resp, err = s.Do(rr.Request())
if err == nil {
return resp, err
}
DelayForBackoff(backoff, attempt, r.Cancel)
}
return resp, err
})
}
}
// WithLogging returns a SendDecorator that implements simple before and after logging of the
// request.
func WithLogging(logger *log.Logger) SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
logger.Printf("Sending %s %s", r.Method, r.URL)
resp, err := s.Do(r)
if err != nil {
logger.Printf("%s %s received error '%v'", r.Method, r.URL, err)
} else {
logger.Printf("%s %s received %s", r.Method, r.URL, resp.Status)
}
return resp, err
})
}
}
// DelayForBackoff invokes time.After for the supplied backoff duration raised to the power of
// passed attempt (i.e., an exponential backoff delay). Backoff duration is in seconds and can set
// to zero for no delay. The delay may be canceled by closing the passed channel. If terminated early,
// returns false.
// Note: Passing attempt 1 will result in doubling "backoff" duration. Treat this as a zero-based attempt
// count.
func DelayForBackoff(backoff time.Duration, attempt int, cancel <-chan struct{}) bool {
select {
case <-time.After(time.Duration(backoff.Seconds()*math.Pow(2, float64(attempt))) * time.Second):
return true
case <-cancel:
return false
}
}

View File

@@ -1,811 +0,0 @@
package autorest
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"bytes"
"fmt"
"log"
"net/http"
"os"
"reflect"
"sync"
"testing"
"time"
"github.com/Azure/go-autorest/autorest/mocks"
)
func ExampleSendWithSender() {
r := mocks.NewResponseWithStatus("202 Accepted", http.StatusAccepted)
mocks.SetAcceptedHeaders(r)
client := mocks.NewSender()
client.AppendAndRepeatResponse(r, 10)
logger := log.New(os.Stdout, "autorest: ", 0)
na := NullAuthorizer{}
req, _ := Prepare(&http.Request{},
AsGet(),
WithBaseURL("https://microsoft.com/a/b/c/"),
na.WithAuthorization())
r, _ = SendWithSender(client, req,
WithLogging(logger),
DoErrorIfStatusCode(http.StatusAccepted),
DoCloseIfError(),
DoRetryForAttempts(5, time.Duration(0)))
Respond(r,
ByDiscardingBody(),
ByClosing())
// Output:
// autorest: Sending GET https://microsoft.com/a/b/c/
// autorest: GET https://microsoft.com/a/b/c/ received 202 Accepted
// autorest: Sending GET https://microsoft.com/a/b/c/
// autorest: GET https://microsoft.com/a/b/c/ received 202 Accepted
// autorest: Sending GET https://microsoft.com/a/b/c/
// autorest: GET https://microsoft.com/a/b/c/ received 202 Accepted
// autorest: Sending GET https://microsoft.com/a/b/c/
// autorest: GET https://microsoft.com/a/b/c/ received 202 Accepted
// autorest: Sending GET https://microsoft.com/a/b/c/
// autorest: GET https://microsoft.com/a/b/c/ received 202 Accepted
}
func ExampleDoRetryForAttempts() {
client := mocks.NewSender()
client.SetAndRepeatError(fmt.Errorf("Faux Error"), 10)
// Retry with backoff -- ensure returned Bodies are closed
r, _ := SendWithSender(client, mocks.NewRequest(),
DoCloseIfError(),
DoRetryForAttempts(5, time.Duration(0)))
Respond(r,
ByDiscardingBody(),
ByClosing())
fmt.Printf("Retry stopped after %d attempts", client.Attempts())
// Output: Retry stopped after 5 attempts
}
func ExampleDoErrorIfStatusCode() {
client := mocks.NewSender()
client.AppendAndRepeatResponse(mocks.NewResponseWithStatus("204 NoContent", http.StatusNoContent), 10)
// Chain decorators to retry the request, up to five times, if the status code is 204
r, _ := SendWithSender(client, mocks.NewRequest(),
DoErrorIfStatusCode(http.StatusNoContent),
DoCloseIfError(),
DoRetryForAttempts(5, time.Duration(0)))
Respond(r,
ByDiscardingBody(),
ByClosing())
fmt.Printf("Retry stopped after %d attempts with code %s", client.Attempts(), r.Status)
// Output: Retry stopped after 5 attempts with code 204 NoContent
}
func TestSendWithSenderRunsDecoratorsInOrder(t *testing.T) {
client := mocks.NewSender()
s := ""
r, err := SendWithSender(client, mocks.NewRequest(),
withMessage(&s, "a"),
withMessage(&s, "b"),
withMessage(&s, "c"))
if err != nil {
t.Fatalf("autorest: SendWithSender returned an error (%v)", err)
}
Respond(r,
ByDiscardingBody(),
ByClosing())
if s != "abc" {
t.Fatalf("autorest: SendWithSender invoke decorators out of order; expected 'abc', received '%s'", s)
}
}
func TestCreateSender(t *testing.T) {
f := false
s := CreateSender(
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
f = true
return nil, nil
})
}
})())
s.Do(&http.Request{})
if !f {
t.Fatal("autorest: CreateSender failed to apply supplied decorator")
}
}
func TestSend(t *testing.T) {
f := false
Send(&http.Request{},
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
f = true
return nil, nil
})
}
})())
if !f {
t.Fatal("autorest: Send failed to apply supplied decorator")
}
}
func TestAfterDelayWaits(t *testing.T) {
client := mocks.NewSender()
d := 2 * time.Second
tt := time.Now()
r, _ := SendWithSender(client, mocks.NewRequest(),
AfterDelay(d))
s := time.Since(tt)
if s < d {
t.Fatal("autorest: AfterDelay failed to wait for at least the specified duration")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestAfterDelay_Cancels(t *testing.T) {
client := mocks.NewSender()
cancel := make(chan struct{})
delay := 5 * time.Second
var wg sync.WaitGroup
wg.Add(1)
tt := time.Now()
go func() {
req := mocks.NewRequest()
req.Cancel = cancel
wg.Done()
SendWithSender(client, req,
AfterDelay(delay))
}()
wg.Wait()
close(cancel)
time.Sleep(5 * time.Millisecond)
if time.Since(tt) >= delay {
t.Fatal("autorest: AfterDelay failed to cancel")
}
}
func TestAfterDelayDoesNotWaitTooLong(t *testing.T) {
client := mocks.NewSender()
d := 5 * time.Millisecond
start := time.Now()
r, _ := SendWithSender(client, mocks.NewRequest(),
AfterDelay(d))
if time.Since(start) > (5 * d) {
t.Fatal("autorest: AfterDelay waited too long (exceeded 5 times specified duration)")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestAsIs(t *testing.T) {
client := mocks.NewSender()
r1 := mocks.NewResponse()
client.AppendResponse(r1)
r2, err := SendWithSender(client, mocks.NewRequest(),
AsIs())
if err != nil {
t.Fatalf("autorest: AsIs returned an unexpected error (%v)", err)
} else if !reflect.DeepEqual(r1, r2) {
t.Fatalf("autorest: AsIs modified the response -- received %v, expected %v", r2, r1)
}
Respond(r1,
ByDiscardingBody(),
ByClosing())
Respond(r2,
ByDiscardingBody(),
ByClosing())
}
func TestDoCloseIfError(t *testing.T) {
client := mocks.NewSender()
client.AppendResponse(mocks.NewResponseWithStatus("400 BadRequest", http.StatusBadRequest))
r, _ := SendWithSender(client, mocks.NewRequest(),
DoErrorIfStatusCode(http.StatusBadRequest),
DoCloseIfError())
if r.Body.(*mocks.Body).IsOpen() {
t.Fatal("autorest: Expected DoCloseIfError to close response body -- it was left open")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoCloseIfErrorAcceptsNilResponse(t *testing.T) {
client := mocks.NewSender()
SendWithSender(client, mocks.NewRequest(),
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
resp, err := s.Do(r)
if err != nil {
resp.Body.Close()
}
return nil, fmt.Errorf("Faux Error")
})
}
})(),
DoCloseIfError())
}
func TestDoCloseIfErrorAcceptsNilBody(t *testing.T) {
client := mocks.NewSender()
SendWithSender(client, mocks.NewRequest(),
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
resp, err := s.Do(r)
if err != nil {
resp.Body.Close()
}
resp.Body = nil
return resp, fmt.Errorf("Faux Error")
})
}
})(),
DoCloseIfError())
}
func TestDoErrorIfStatusCode(t *testing.T) {
client := mocks.NewSender()
client.AppendResponse(mocks.NewResponseWithStatus("400 BadRequest", http.StatusBadRequest))
r, err := SendWithSender(client, mocks.NewRequest(),
DoErrorIfStatusCode(http.StatusBadRequest),
DoCloseIfError())
if err == nil {
t.Fatal("autorest: DoErrorIfStatusCode failed to emit an error for passed code")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoErrorIfStatusCodeIgnoresStatusCodes(t *testing.T) {
client := mocks.NewSender()
client.AppendResponse(newAcceptedResponse())
r, err := SendWithSender(client, mocks.NewRequest(),
DoErrorIfStatusCode(http.StatusBadRequest),
DoCloseIfError())
if err != nil {
t.Fatal("autorest: DoErrorIfStatusCode failed to ignore a status code")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoErrorUnlessStatusCode(t *testing.T) {
client := mocks.NewSender()
client.AppendResponse(mocks.NewResponseWithStatus("400 BadRequest", http.StatusBadRequest))
r, err := SendWithSender(client, mocks.NewRequest(),
DoErrorUnlessStatusCode(http.StatusAccepted),
DoCloseIfError())
if err == nil {
t.Fatal("autorest: DoErrorUnlessStatusCode failed to emit an error for an unknown status code")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoErrorUnlessStatusCodeIgnoresStatusCodes(t *testing.T) {
client := mocks.NewSender()
client.AppendResponse(newAcceptedResponse())
r, err := SendWithSender(client, mocks.NewRequest(),
DoErrorUnlessStatusCode(http.StatusAccepted),
DoCloseIfError())
if err != nil {
t.Fatal("autorest: DoErrorUnlessStatusCode emitted an error for a knonwn status code")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoRetryForAttemptsStopsAfterSuccess(t *testing.T) {
client := mocks.NewSender()
r, err := SendWithSender(client, mocks.NewRequest(),
DoRetryForAttempts(5, time.Duration(0)))
if client.Attempts() != 1 {
t.Fatalf("autorest: DoRetryForAttempts failed to stop after success -- expected attempts %v, actual %v",
1, client.Attempts())
}
if err != nil {
t.Fatalf("autorest: DoRetryForAttempts returned an unexpected error (%v)", err)
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoRetryForAttemptsStopsAfterAttempts(t *testing.T) {
client := mocks.NewSender()
client.SetAndRepeatError(fmt.Errorf("Faux Error"), 10)
r, err := SendWithSender(client, mocks.NewRequest(),
DoRetryForAttempts(5, time.Duration(0)),
DoCloseIfError())
if err == nil {
t.Fatal("autorest: Mock client failed to emit errors")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
if client.Attempts() != 5 {
t.Fatal("autorest: DoRetryForAttempts failed to stop after specified number of attempts")
}
}
func TestDoRetryForAttemptsReturnsResponse(t *testing.T) {
client := mocks.NewSender()
client.SetError(fmt.Errorf("Faux Error"))
r, err := SendWithSender(client, mocks.NewRequest(),
DoRetryForAttempts(1, time.Duration(0)))
if err == nil {
t.Fatal("autorest: Mock client failed to emit errors")
}
if r == nil {
t.Fatal("autorest: DoRetryForAttempts failed to return the underlying response")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoRetryForDurationStopsAfterSuccess(t *testing.T) {
client := mocks.NewSender()
r, err := SendWithSender(client, mocks.NewRequest(),
DoRetryForDuration(10*time.Millisecond, time.Duration(0)))
if client.Attempts() != 1 {
t.Fatalf("autorest: DoRetryForDuration failed to stop after success -- expected attempts %v, actual %v",
1, client.Attempts())
}
if err != nil {
t.Fatalf("autorest: DoRetryForDuration returned an unexpected error (%v)", err)
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoRetryForDurationStopsAfterDuration(t *testing.T) {
client := mocks.NewSender()
client.SetAndRepeatError(fmt.Errorf("Faux Error"), -1)
d := 5 * time.Millisecond
start := time.Now()
r, err := SendWithSender(client, mocks.NewRequest(),
DoRetryForDuration(d, time.Duration(0)),
DoCloseIfError())
if err == nil {
t.Fatal("autorest: Mock client failed to emit errors")
}
if time.Since(start) < d {
t.Fatal("autorest: DoRetryForDuration failed stopped too soon")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoRetryForDurationStopsWithinReason(t *testing.T) {
client := mocks.NewSender()
client.SetAndRepeatError(fmt.Errorf("Faux Error"), -1)
d := 5 * time.Second
start := time.Now()
r, err := SendWithSender(client, mocks.NewRequest(),
DoRetryForDuration(d, time.Duration(0)),
DoCloseIfError())
if err == nil {
t.Fatal("autorest: Mock client failed to emit errors")
}
if time.Since(start) > (5 * d) {
t.Fatal("autorest: DoRetryForDuration failed stopped soon enough (exceeded 5 times specified duration)")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoRetryForDurationReturnsResponse(t *testing.T) {
client := mocks.NewSender()
client.SetAndRepeatError(fmt.Errorf("Faux Error"), -1)
r, err := SendWithSender(client, mocks.NewRequest(),
DoRetryForDuration(10*time.Millisecond, time.Duration(0)),
DoCloseIfError())
if err == nil {
t.Fatal("autorest: Mock client failed to emit errors")
}
if r == nil {
t.Fatal("autorest: DoRetryForDuration failed to return the underlying response")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDelayForBackoff(t *testing.T) {
d := 2 * time.Second
start := time.Now()
DelayForBackoff(d, 0, nil)
if time.Since(start) < d {
t.Fatal("autorest: DelayForBackoff did not delay as long as expected")
}
}
func TestDelayForBackoff_Cancels(t *testing.T) {
cancel := make(chan struct{})
delay := 5 * time.Second
var wg sync.WaitGroup
wg.Add(1)
start := time.Now()
go func() {
wg.Done()
DelayForBackoff(delay, 0, cancel)
}()
wg.Wait()
close(cancel)
time.Sleep(5 * time.Millisecond)
if time.Since(start) >= delay {
t.Fatal("autorest: DelayForBackoff failed to cancel")
}
}
func TestDelayForBackoffWithinReason(t *testing.T) {
d := 5 * time.Second
maxCoefficient := 2
start := time.Now()
DelayForBackoff(d, 0, nil)
if time.Since(start) > (time.Duration(maxCoefficient) * d) {
t.Fatalf("autorest: DelayForBackoff delayed too long (exceeded %d times the specified duration)", maxCoefficient)
}
}
func TestDoPollForStatusCodes_IgnoresUnspecifiedStatusCodes(t *testing.T) {
client := mocks.NewSender()
r, _ := SendWithSender(client, mocks.NewRequest(),
DoPollForStatusCodes(time.Duration(0), time.Duration(0)))
if client.Attempts() != 1 {
t.Fatalf("autorest: Sender#DoPollForStatusCodes polled for unspecified status code")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoPollForStatusCodes_PollsForSpecifiedStatusCodes(t *testing.T) {
client := mocks.NewSender()
client.AppendResponse(newAcceptedResponse())
r, _ := SendWithSender(client, mocks.NewRequest(),
DoPollForStatusCodes(time.Millisecond, time.Millisecond, http.StatusAccepted))
if client.Attempts() != 2 {
t.Fatalf("autorest: Sender#DoPollForStatusCodes failed to poll for specified status code")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoPollForStatusCodes_CanBeCanceled(t *testing.T) {
cancel := make(chan struct{})
delay := 5 * time.Second
r := mocks.NewResponse()
mocks.SetAcceptedHeaders(r)
client := mocks.NewSender()
client.AppendAndRepeatResponse(r, 100)
var wg sync.WaitGroup
wg.Add(1)
start := time.Now()
go func() {
wg.Done()
r, _ := SendWithSender(client, mocks.NewRequest(),
DoPollForStatusCodes(time.Millisecond, time.Millisecond, http.StatusAccepted))
Respond(r,
ByDiscardingBody(),
ByClosing())
}()
wg.Wait()
close(cancel)
time.Sleep(5 * time.Millisecond)
if time.Since(start) >= delay {
t.Fatalf("autorest: Sender#DoPollForStatusCodes failed to cancel")
}
}
func TestDoPollForStatusCodes_ClosesAllNonreturnedResponseBodiesWhenPolling(t *testing.T) {
resp := newAcceptedResponse()
client := mocks.NewSender()
client.AppendAndRepeatResponse(resp, 2)
r, _ := SendWithSender(client, mocks.NewRequest(),
DoPollForStatusCodes(time.Millisecond, time.Millisecond, http.StatusAccepted))
if resp.Body.(*mocks.Body).IsOpen() || resp.Body.(*mocks.Body).CloseAttempts() < 2 {
t.Fatalf("autorest: Sender#DoPollForStatusCodes did not close unreturned response bodies")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoPollForStatusCodes_LeavesLastResponseBodyOpen(t *testing.T) {
client := mocks.NewSender()
client.AppendResponse(newAcceptedResponse())
r, _ := SendWithSender(client, mocks.NewRequest(),
DoPollForStatusCodes(time.Millisecond, time.Millisecond, http.StatusAccepted))
if !r.Body.(*mocks.Body).IsOpen() {
t.Fatalf("autorest: Sender#DoPollForStatusCodes did not leave open the body of the last response")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoPollForStatusCodes_StopsPollingAfterAnError(t *testing.T) {
client := mocks.NewSender()
client.AppendAndRepeatResponse(newAcceptedResponse(), 5)
client.SetError(fmt.Errorf("Faux Error"))
client.SetEmitErrorAfter(1)
r, _ := SendWithSender(client, mocks.NewRequest(),
DoPollForStatusCodes(time.Millisecond, time.Millisecond, http.StatusAccepted))
if client.Attempts() > 2 {
t.Fatalf("autorest: Sender#DoPollForStatusCodes failed to stop polling after receiving an error")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoPollForStatusCodes_ReturnsPollingError(t *testing.T) {
client := mocks.NewSender()
client.AppendAndRepeatResponse(newAcceptedResponse(), 5)
client.SetError(fmt.Errorf("Faux Error"))
client.SetEmitErrorAfter(1)
r, err := SendWithSender(client, mocks.NewRequest(),
DoPollForStatusCodes(time.Millisecond, time.Millisecond, http.StatusAccepted))
if err == nil {
t.Fatalf("autorest: Sender#DoPollForStatusCodes failed to return error from polling")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestWithLogging_Logs(t *testing.T) {
buf := &bytes.Buffer{}
logger := log.New(buf, "autorest: ", 0)
client := mocks.NewSender()
r, _ := SendWithSender(client, &http.Request{},
WithLogging(logger))
if buf.String() == "" {
t.Fatal("autorest: Sender#WithLogging failed to log the request")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestWithLogging_HandlesMissingResponse(t *testing.T) {
buf := &bytes.Buffer{}
logger := log.New(buf, "autorest: ", 0)
client := mocks.NewSender()
client.AppendResponse(nil)
client.SetError(fmt.Errorf("Faux Error"))
r, err := SendWithSender(client, &http.Request{},
WithLogging(logger))
if r != nil || err == nil {
t.Fatal("autorest: Sender#WithLogging returned a valid response -- expecting nil")
}
if buf.String() == "" {
t.Fatal("autorest: Sender#WithLogging failed to log the request for a nil response")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
}
func TestDoRetryForStatusCodesWithSuccess(t *testing.T) {
client := mocks.NewSender()
client.AppendAndRepeatResponse(mocks.NewResponseWithStatus("408 Request Timeout", http.StatusRequestTimeout), 2)
client.AppendResponse(mocks.NewResponseWithStatus("200 OK", http.StatusOK))
r, _ := SendWithSender(client, mocks.NewRequest(),
DoRetryForStatusCodes(5, time.Duration(2*time.Second), http.StatusRequestTimeout),
)
Respond(r,
ByDiscardingBody(),
ByClosing())
if client.Attempts() != 3 {
t.Fatalf("autorest: Sender#DoRetryForStatusCodes -- Got: StatusCode %v in %v attempts; Want: StatusCode 200 OK in 2 attempts -- ",
r.Status, client.Attempts()-1)
}
}
func TestDoRetryForStatusCodesWithNoSuccess(t *testing.T) {
client := mocks.NewSender()
client.AppendAndRepeatResponse(mocks.NewResponseWithStatus("504 Gateway Timeout", http.StatusGatewayTimeout), 5)
r, _ := SendWithSender(client, mocks.NewRequest(),
DoRetryForStatusCodes(2, time.Duration(2*time.Second), http.StatusGatewayTimeout),
)
Respond(r,
ByDiscardingBody(),
ByClosing())
if client.Attempts() != 3 {
t.Fatalf("autorest: Sender#DoRetryForStatusCodes -- Got: failed stop after %v retry attempts; Want: Stop after 2 retry attempts",
client.Attempts()-1)
}
}
func TestDoRetryForStatusCodes_CodeNotInRetryList(t *testing.T) {
client := mocks.NewSender()
client.AppendAndRepeatResponse(mocks.NewResponseWithStatus("204 No Content", http.StatusNoContent), 1)
r, _ := SendWithSender(client, mocks.NewRequest(),
DoRetryForStatusCodes(6, time.Duration(2*time.Second), http.StatusGatewayTimeout),
)
Respond(r,
ByDiscardingBody(),
ByClosing())
if client.Attempts() != 1 || r.Status != "204 No Content" {
t.Fatalf("autorest: Sender#DoRetryForStatusCodes -- Got: Retry attempts %v for StatusCode %v; Want: 0 attempts for StatusCode 204",
client.Attempts(), r.Status)
}
}
func TestDoRetryForStatusCodes_RequestBodyReadError(t *testing.T) {
client := mocks.NewSender()
client.AppendAndRepeatResponse(mocks.NewResponseWithStatus("204 No Content", http.StatusNoContent), 2)
r, err := SendWithSender(client, mocks.NewRequestWithCloseBody(),
DoRetryForStatusCodes(6, time.Duration(2*time.Second), http.StatusGatewayTimeout),
)
Respond(r,
ByDiscardingBody(),
ByClosing())
if err == nil || client.Attempts() != 0 {
t.Fatalf("autorest: Sender#DoRetryForStatusCodes -- Got: Not failed for request body read error; Want: Failed for body read error - %v", err)
}
}
func newAcceptedResponse() *http.Response {
resp := mocks.NewResponseWithStatus("202 Accepted", http.StatusAccepted)
mocks.SetAcceptedHeaders(resp)
return resp
}
func TestDelayWithRetryAfterWithSuccess(t *testing.T) {
after, retries := 5, 2
totalSecs := after * retries
client := mocks.NewSender()
resp := mocks.NewResponseWithStatus("429 Too many requests", http.StatusTooManyRequests)
mocks.SetResponseHeader(resp, "Retry-After", fmt.Sprintf("%v", after))
client.AppendAndRepeatResponse(resp, retries)
client.AppendResponse(mocks.NewResponseWithStatus("200 OK", http.StatusOK))
d := time.Second * time.Duration(totalSecs)
start := time.Now()
r, _ := SendWithSender(client, mocks.NewRequest(),
DoRetryForStatusCodes(5, time.Duration(time.Second), http.StatusTooManyRequests),
)
if time.Since(start) < d {
t.Fatal("autorest: DelayWithRetryAfter failed stopped too soon")
}
Respond(r,
ByDiscardingBody(),
ByClosing())
if client.Attempts() != 3 {
t.Fatalf("autorest: Sender#DelayWithRetryAfter -- Got: StatusCode %v in %v attempts; Want: StatusCode 200 OK in 2 attempts -- ",
r.Status, client.Attempts()-1)
}
}

View File

@@ -1,147 +0,0 @@
/*
Package to provides helpers to ease working with pointer values of marshalled structures.
*/
package to
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// String returns a string value for the passed string pointer. It returns the empty string if the
// pointer is nil.
func String(s *string) string {
if s != nil {
return *s
}
return ""
}
// StringPtr returns a pointer to the passed string.
func StringPtr(s string) *string {
return &s
}
// StringSlice returns a string slice value for the passed string slice pointer. It returns a nil
// slice if the pointer is nil.
func StringSlice(s *[]string) []string {
if s != nil {
return *s
}
return nil
}
// StringSlicePtr returns a pointer to the passed string slice.
func StringSlicePtr(s []string) *[]string {
return &s
}
// StringMap returns a map of strings built from the map of string pointers. The empty string is
// used for nil pointers.
func StringMap(msp map[string]*string) map[string]string {
ms := make(map[string]string, len(msp))
for k, sp := range msp {
if sp != nil {
ms[k] = *sp
} else {
ms[k] = ""
}
}
return ms
}
// StringMapPtr returns a pointer to a map of string pointers built from the passed map of strings.
func StringMapPtr(ms map[string]string) *map[string]*string {
msp := make(map[string]*string, len(ms))
for k, s := range ms {
msp[k] = StringPtr(s)
}
return &msp
}
// Bool returns a bool value for the passed bool pointer. It returns false if the pointer is nil.
func Bool(b *bool) bool {
if b != nil {
return *b
}
return false
}
// BoolPtr returns a pointer to the passed bool.
func BoolPtr(b bool) *bool {
return &b
}
// Int returns an int value for the passed int pointer. It returns 0 if the pointer is nil.
func Int(i *int) int {
if i != nil {
return *i
}
return 0
}
// IntPtr returns a pointer to the passed int.
func IntPtr(i int) *int {
return &i
}
// Int32 returns an int value for the passed int pointer. It returns 0 if the pointer is nil.
func Int32(i *int32) int32 {
if i != nil {
return *i
}
return 0
}
// Int32Ptr returns a pointer to the passed int32.
func Int32Ptr(i int32) *int32 {
return &i
}
// Int64 returns an int value for the passed int pointer. It returns 0 if the pointer is nil.
func Int64(i *int64) int64 {
if i != nil {
return *i
}
return 0
}
// Int64Ptr returns a pointer to the passed int64.
func Int64Ptr(i int64) *int64 {
return &i
}
// Float32 returns an int value for the passed int pointer. It returns 0.0 if the pointer is nil.
func Float32(i *float32) float32 {
if i != nil {
return *i
}
return 0.0
}
// Float32Ptr returns a pointer to the passed float32.
func Float32Ptr(i float32) *float32 {
return &i
}
// Float64 returns an int value for the passed int pointer. It returns 0.0 if the pointer is nil.
func Float64(i *float64) float64 {
if i != nil {
return *i
}
return 0.0
}
// Float64Ptr returns a pointer to the passed float64.
func Float64Ptr(i float64) *float64 {
return &i
}

View File

@@ -1,234 +0,0 @@
package to
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"reflect"
"testing"
)
func TestString(t *testing.T) {
v := ""
if String(&v) != v {
t.Fatalf("to: String failed to return the correct string -- expected %v, received %v",
v, String(&v))
}
}
func TestStringHandlesNil(t *testing.T) {
if String(nil) != "" {
t.Fatalf("to: String failed to correctly convert nil -- expected %v, received %v",
"", String(nil))
}
}
func TestStringPtr(t *testing.T) {
v := ""
if *StringPtr(v) != v {
t.Fatalf("to: StringPtr failed to return the correct string -- expected %v, received %v",
v, *StringPtr(v))
}
}
func TestStringSlice(t *testing.T) {
v := []string{}
if out := StringSlice(&v); !reflect.DeepEqual(out, v) {
t.Fatalf("to: StringSlice failed to return the correct slice -- expected %v, received %v",
v, out)
}
}
func TestStringSliceHandlesNil(t *testing.T) {
if out := StringSlice(nil); out != nil {
t.Fatalf("to: StringSlice failed to correctly convert nil -- expected %v, received %v",
nil, out)
}
}
func TestStringSlicePtr(t *testing.T) {
v := []string{"a", "b"}
if out := StringSlicePtr(v); !reflect.DeepEqual(*out, v) {
t.Fatalf("to: StringSlicePtr failed to return the correct slice -- expected %v, received %v",
v, *out)
}
}
func TestStringMap(t *testing.T) {
msp := map[string]*string{"foo": StringPtr("foo"), "bar": StringPtr("bar"), "baz": StringPtr("baz")}
for k, v := range StringMap(msp) {
if *msp[k] != v {
t.Fatalf("to: StringMap incorrectly converted an entry -- expected [%s]%v, received[%s]%v",
k, v, k, *msp[k])
}
}
}
func TestStringMapHandlesNil(t *testing.T) {
msp := map[string]*string{"foo": StringPtr("foo"), "bar": nil, "baz": StringPtr("baz")}
for k, v := range StringMap(msp) {
if msp[k] == nil && v != "" {
t.Fatalf("to: StringMap incorrectly converted a nil entry -- expected [%s]%v, received[%s]%v",
k, v, k, *msp[k])
}
}
}
func TestStringMapPtr(t *testing.T) {
ms := map[string]string{"foo": "foo", "bar": "bar", "baz": "baz"}
for k, msp := range *StringMapPtr(ms) {
if ms[k] != *msp {
t.Fatalf("to: StringMapPtr incorrectly converted an entry -- expected [%s]%v, received[%s]%v",
k, ms[k], k, *msp)
}
}
}
func TestBool(t *testing.T) {
v := false
if Bool(&v) != v {
t.Fatalf("to: Bool failed to return the correct string -- expected %v, received %v",
v, Bool(&v))
}
}
func TestBoolHandlesNil(t *testing.T) {
if Bool(nil) != false {
t.Fatalf("to: Bool failed to correctly convert nil -- expected %v, received %v",
false, Bool(nil))
}
}
func TestBoolPtr(t *testing.T) {
v := false
if *BoolPtr(v) != v {
t.Fatalf("to: BoolPtr failed to return the correct string -- expected %v, received %v",
v, *BoolPtr(v))
}
}
func TestInt(t *testing.T) {
v := 0
if Int(&v) != v {
t.Fatalf("to: Int failed to return the correct string -- expected %v, received %v",
v, Int(&v))
}
}
func TestIntHandlesNil(t *testing.T) {
if Int(nil) != 0 {
t.Fatalf("to: Int failed to correctly convert nil -- expected %v, received %v",
0, Int(nil))
}
}
func TestIntPtr(t *testing.T) {
v := 0
if *IntPtr(v) != v {
t.Fatalf("to: IntPtr failed to return the correct string -- expected %v, received %v",
v, *IntPtr(v))
}
}
func TestInt32(t *testing.T) {
v := int32(0)
if Int32(&v) != v {
t.Fatalf("to: Int32 failed to return the correct string -- expected %v, received %v",
v, Int32(&v))
}
}
func TestInt32HandlesNil(t *testing.T) {
if Int32(nil) != int32(0) {
t.Fatalf("to: Int32 failed to correctly convert nil -- expected %v, received %v",
0, Int32(nil))
}
}
func TestInt32Ptr(t *testing.T) {
v := int32(0)
if *Int32Ptr(v) != v {
t.Fatalf("to: Int32Ptr failed to return the correct string -- expected %v, received %v",
v, *Int32Ptr(v))
}
}
func TestInt64(t *testing.T) {
v := int64(0)
if Int64(&v) != v {
t.Fatalf("to: Int64 failed to return the correct string -- expected %v, received %v",
v, Int64(&v))
}
}
func TestInt64HandlesNil(t *testing.T) {
if Int64(nil) != int64(0) {
t.Fatalf("to: Int64 failed to correctly convert nil -- expected %v, received %v",
0, Int64(nil))
}
}
func TestInt64Ptr(t *testing.T) {
v := int64(0)
if *Int64Ptr(v) != v {
t.Fatalf("to: Int64Ptr failed to return the correct string -- expected %v, received %v",
v, *Int64Ptr(v))
}
}
func TestFloat32(t *testing.T) {
v := float32(0)
if Float32(&v) != v {
t.Fatalf("to: Float32 failed to return the correct string -- expected %v, received %v",
v, Float32(&v))
}
}
func TestFloat32HandlesNil(t *testing.T) {
if Float32(nil) != float32(0) {
t.Fatalf("to: Float32 failed to correctly convert nil -- expected %v, received %v",
0, Float32(nil))
}
}
func TestFloat32Ptr(t *testing.T) {
v := float32(0)
if *Float32Ptr(v) != v {
t.Fatalf("to: Float32Ptr failed to return the correct string -- expected %v, received %v",
v, *Float32Ptr(v))
}
}
func TestFloat64(t *testing.T) {
v := float64(0)
if Float64(&v) != v {
t.Fatalf("to: Float64 failed to return the correct string -- expected %v, received %v",
v, Float64(&v))
}
}
func TestFloat64HandlesNil(t *testing.T) {
if Float64(nil) != float64(0) {
t.Fatalf("to: Float64 failed to correctly convert nil -- expected %v, received %v",
0, Float64(nil))
}
}
func TestFloat64Ptr(t *testing.T) {
v := float64(0)
if *Float64Ptr(v) != v {
t.Fatalf("to: Float64Ptr failed to return the correct string -- expected %v, received %v",
v, *Float64Ptr(v))
}
}

View File

@@ -1,204 +0,0 @@
package autorest
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"bytes"
"encoding/json"
"encoding/xml"
"fmt"
"io"
"net/http"
"net/url"
"reflect"
"sort"
"strings"
)
// EncodedAs is a series of constants specifying various data encodings
type EncodedAs string
const (
// EncodedAsJSON states that data is encoded as JSON
EncodedAsJSON EncodedAs = "JSON"
// EncodedAsXML states that data is encoded as Xml
EncodedAsXML EncodedAs = "XML"
)
// Decoder defines the decoding method json.Decoder and xml.Decoder share
type Decoder interface {
Decode(v interface{}) error
}
// NewDecoder creates a new decoder appropriate to the passed encoding.
// encodedAs specifies the type of encoding and r supplies the io.Reader containing the
// encoded data.
func NewDecoder(encodedAs EncodedAs, r io.Reader) Decoder {
if encodedAs == EncodedAsJSON {
return json.NewDecoder(r)
} else if encodedAs == EncodedAsXML {
return xml.NewDecoder(r)
}
return nil
}
// CopyAndDecode decodes the data from the passed io.Reader while making a copy. Having a copy
// is especially useful if there is a chance the data will fail to decode.
// encodedAs specifies the expected encoding, r provides the io.Reader to the data, and v
// is the decoding destination.
func CopyAndDecode(encodedAs EncodedAs, r io.Reader, v interface{}) (bytes.Buffer, error) {
b := bytes.Buffer{}
return b, NewDecoder(encodedAs, io.TeeReader(r, &b)).Decode(v)
}
// TeeReadCloser returns a ReadCloser that writes to w what it reads from rc.
// It utilizes io.TeeReader to copy the data read and has the same behavior when reading.
// Further, when it is closed, it ensures that rc is closed as well.
func TeeReadCloser(rc io.ReadCloser, w io.Writer) io.ReadCloser {
return &teeReadCloser{rc, io.TeeReader(rc, w)}
}
type teeReadCloser struct {
rc io.ReadCloser
r io.Reader
}
func (t *teeReadCloser) Read(p []byte) (int, error) {
return t.r.Read(p)
}
func (t *teeReadCloser) Close() error {
return t.rc.Close()
}
func containsInt(ints []int, n int) bool {
for _, i := range ints {
if i == n {
return true
}
}
return false
}
func escapeValueStrings(m map[string]string) map[string]string {
for key, value := range m {
m[key] = url.QueryEscape(value)
}
return m
}
func ensureValueStrings(mapOfInterface map[string]interface{}) map[string]string {
mapOfStrings := make(map[string]string)
for key, value := range mapOfInterface {
mapOfStrings[key] = ensureValueString(value)
}
return mapOfStrings
}
func ensureValueString(value interface{}) string {
if value == nil {
return ""
}
switch v := value.(type) {
case string:
return v
case []byte:
return string(v)
default:
return fmt.Sprintf("%v", v)
}
}
// MapToValues method converts map[string]interface{} to url.Values.
func MapToValues(m map[string]interface{}) url.Values {
v := url.Values{}
for key, value := range m {
x := reflect.ValueOf(value)
if x.Kind() == reflect.Array || x.Kind() == reflect.Slice {
for i := 0; i < x.Len(); i++ {
v.Add(key, ensureValueString(x.Index(i)))
}
} else {
v.Add(key, ensureValueString(value))
}
}
return v
}
// String method converts interface v to string. If interface is a list, it
// joins list elements using separator.
func String(v interface{}, sep ...string) string {
if len(sep) > 0 {
return ensureValueString(strings.Join(v.([]string), sep[0]))
}
return ensureValueString(v)
}
// Encode method encodes url path and query parameters.
func Encode(location string, v interface{}, sep ...string) string {
s := String(v, sep...)
switch strings.ToLower(location) {
case "path":
return pathEscape(s)
case "query":
return queryEscape(s)
default:
return s
}
}
func pathEscape(s string) string {
return strings.Replace(url.QueryEscape(s), "+", "%20", -1)
}
func queryEscape(s string) string {
return url.QueryEscape(s)
}
// This method is same as Encode() method of "net/url" go package,
// except it does not encode the query parameters because they
// already come encoded. It formats values map in query format (bar=foo&a=b).
func createQuery(v url.Values) string {
var buf bytes.Buffer
keys := make([]string, 0, len(v))
for k := range v {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
vs := v[k]
prefix := url.QueryEscape(k) + "="
for _, v := range vs {
if buf.Len() > 0 {
buf.WriteByte('&')
}
buf.WriteString(prefix)
buf.WriteString(v)
}
}
return buf.String()
}
// ChangeToGet turns the specified http.Request into a GET (it assumes it wasn't).
// This is mainly useful for long-running operations that use the Azure-AsyncOperation
// header, so we change the initial PUT into a GET to retrieve the final result.
func ChangeToGet(req *http.Request) *http.Request {
req.Method = "GET"
req.Body = nil
req.ContentLength = 0
req.Header.Del("Content-Length")
return req
}

View File

@@ -1,382 +0,0 @@
package autorest
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"bytes"
"encoding/json"
"encoding/xml"
"fmt"
"net/http"
"net/url"
"reflect"
"sort"
"strings"
"testing"
"github.com/Azure/go-autorest/autorest/mocks"
)
const (
jsonT = `
{
"name":"Rob Pike",
"age":42
}`
xmlT = `<?xml version="1.0" encoding="UTF-8"?>
<Person>
<Name>Rob Pike</Name>
<Age>42</Age>
</Person>`
)
func TestNewDecoderCreatesJSONDecoder(t *testing.T) {
d := NewDecoder(EncodedAsJSON, strings.NewReader(jsonT))
_, ok := d.(*json.Decoder)
if d == nil || !ok {
t.Fatal("autorest: NewDecoder failed to create a JSON decoder when requested")
}
}
func TestNewDecoderCreatesXMLDecoder(t *testing.T) {
d := NewDecoder(EncodedAsXML, strings.NewReader(xmlT))
_, ok := d.(*xml.Decoder)
if d == nil || !ok {
t.Fatal("autorest: NewDecoder failed to create an XML decoder when requested")
}
}
func TestNewDecoderReturnsNilForUnknownEncoding(t *testing.T) {
d := NewDecoder("unknown", strings.NewReader(xmlT))
if d != nil {
t.Fatal("autorest: NewDecoder created a decoder for an unknown encoding")
}
}
func TestCopyAndDecodeDecodesJSON(t *testing.T) {
_, err := CopyAndDecode(EncodedAsJSON, strings.NewReader(jsonT), &mocks.T{})
if err != nil {
t.Fatalf("autorest: CopyAndDecode returned an error with valid JSON - %v", err)
}
}
func TestCopyAndDecodeDecodesXML(t *testing.T) {
_, err := CopyAndDecode(EncodedAsXML, strings.NewReader(xmlT), &mocks.T{})
if err != nil {
t.Fatalf("autorest: CopyAndDecode returned an error with valid XML - %v", err)
}
}
func TestCopyAndDecodeReturnsJSONDecodingErrors(t *testing.T) {
_, err := CopyAndDecode(EncodedAsJSON, strings.NewReader(jsonT[0:len(jsonT)-2]), &mocks.T{})
if err == nil {
t.Fatalf("autorest: CopyAndDecode failed to return an error with invalid JSON")
}
}
func TestCopyAndDecodeReturnsXMLDecodingErrors(t *testing.T) {
_, err := CopyAndDecode(EncodedAsXML, strings.NewReader(xmlT[0:len(xmlT)-2]), &mocks.T{})
if err == nil {
t.Fatalf("autorest: CopyAndDecode failed to return an error with invalid XML")
}
}
func TestCopyAndDecodeAlwaysReturnsACopy(t *testing.T) {
b, _ := CopyAndDecode(EncodedAsJSON, strings.NewReader(jsonT), &mocks.T{})
if b.String() != jsonT {
t.Fatalf("autorest: CopyAndDecode failed to return a valid copy of the data - %v", b.String())
}
}
func TestTeeReadCloser_Copies(t *testing.T) {
v := &mocks.T{}
r := mocks.NewResponseWithContent(jsonT)
b := &bytes.Buffer{}
r.Body = TeeReadCloser(r.Body, b)
err := Respond(r,
ByUnmarshallingJSON(v),
ByClosing())
if err != nil {
t.Fatalf("autorest: TeeReadCloser returned an unexpected error -- %v", err)
}
if b.String() != jsonT {
t.Fatalf("autorest: TeeReadCloser failed to copy the bytes read")
}
}
func TestTeeReadCloser_PassesReadErrors(t *testing.T) {
v := &mocks.T{}
r := mocks.NewResponseWithContent(jsonT)
r.Body.(*mocks.Body).Close()
r.Body = TeeReadCloser(r.Body, &bytes.Buffer{})
err := Respond(r,
ByUnmarshallingJSON(v),
ByClosing())
if err == nil {
t.Fatalf("autorest: TeeReadCloser failed to return the expected error")
}
}
func TestTeeReadCloser_ClosesWrappedReader(t *testing.T) {
v := &mocks.T{}
r := mocks.NewResponseWithContent(jsonT)
b := r.Body.(*mocks.Body)
r.Body = TeeReadCloser(r.Body, &bytes.Buffer{})
err := Respond(r,
ByUnmarshallingJSON(v),
ByClosing())
if err != nil {
t.Fatalf("autorest: TeeReadCloser returned an unexpected error -- %v", err)
}
if b.IsOpen() {
t.Fatalf("autorest: TeeReadCloser failed to close the nested io.ReadCloser")
}
}
func TestContainsIntFindsValue(t *testing.T) {
ints := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
v := 5
if !containsInt(ints, v) {
t.Fatalf("autorest: containsInt failed to find %v in %v", v, ints)
}
}
func TestContainsIntDoesNotFindValue(t *testing.T) {
ints := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
v := 42
if containsInt(ints, v) {
t.Fatalf("autorest: containsInt unexpectedly found %v in %v", v, ints)
}
}
func TestContainsIntAcceptsEmptyList(t *testing.T) {
ints := make([]int, 10)
if containsInt(ints, 42) {
t.Fatalf("autorest: containsInt failed to handle an empty list")
}
}
func TestContainsIntAcceptsNilList(t *testing.T) {
var ints []int
if containsInt(ints, 42) {
t.Fatalf("autorest: containsInt failed to handle an nil list")
}
}
func TestEscapeStrings(t *testing.T) {
m := map[string]string{
"string": "a long string with = odd characters",
"int": "42",
"nil": "",
}
r := map[string]string{
"string": "a+long+string+with+%3D+odd+characters",
"int": "42",
"nil": "",
}
v := escapeValueStrings(m)
if !reflect.DeepEqual(v, r) {
t.Fatalf("autorest: ensureValueStrings returned %v\n", v)
}
}
func TestEnsureStrings(t *testing.T) {
m := map[string]interface{}{
"string": "string",
"int": 42,
"nil": nil,
"bytes": []byte{255, 254, 253},
}
r := map[string]string{
"string": "string",
"int": "42",
"nil": "",
"bytes": string([]byte{255, 254, 253}),
}
v := ensureValueStrings(m)
if !reflect.DeepEqual(v, r) {
t.Fatalf("autorest: ensureValueStrings returned %v\n", v)
}
}
func ExampleString() {
m := []string{
"string1",
"string2",
"string3",
}
fmt.Println(String(m, ","))
// Output: string1,string2,string3
}
func TestStringWithValidString(t *testing.T) {
i := 123
if String(i) != "123" {
t.Fatal("autorest: String method failed to convert integer 123 to string")
}
}
func TestEncodeWithValidPath(t *testing.T) {
s := Encode("Path", "Hello Gopher")
if s != "Hello%20Gopher" {
t.Fatalf("autorest: Encode method failed for valid path encoding. Got: %v; Want: %v", s, "Hello%20Gopher")
}
}
func TestEncodeWithValidQuery(t *testing.T) {
s := Encode("Query", "Hello Gopher")
if s != "Hello+Gopher" {
t.Fatalf("autorest: Encode method failed for valid query encoding. Got: '%v'; Want: 'Hello+Gopher'", s)
}
}
func TestEncodeWithValidNotPathQuery(t *testing.T) {
s := Encode("Host", "Hello Gopher")
if s != "Hello Gopher" {
t.Fatalf("autorest: Encode method failed for parameter not query or path. Got: '%v'; Want: 'Hello Gopher'", s)
}
}
func TestMapToValues(t *testing.T) {
m := map[string]interface{}{
"a": "a",
"b": 2,
}
v := url.Values{}
v.Add("a", "a")
v.Add("b", "2")
if !isEqual(v, MapToValues(m)) {
t.Fatalf("autorest: MapToValues method failed to return correct values - expected(%v) got(%v)", v, MapToValues(m))
}
}
func TestMapToValuesWithArrayValues(t *testing.T) {
m := map[string]interface{}{
"a": []string{"a", "b"},
"b": 2,
"c": []int{3, 4},
}
v := url.Values{}
v.Add("a", "a")
v.Add("a", "b")
v.Add("b", "2")
v.Add("c", "3")
v.Add("c", "4")
if !isEqual(v, MapToValues(m)) {
t.Fatalf("autorest: MapToValues method failed to return correct values - expected(%v) got(%v)", v, MapToValues(m))
}
}
func isEqual(v, u url.Values) bool {
for key, value := range v {
if len(u[key]) == 0 {
return false
}
sort.Strings(value)
sort.Strings(u[key])
for i := range value {
if value[i] != u[key][i] {
return false
}
}
u.Del(key)
}
if len(u) > 0 {
return false
}
return true
}
func doEnsureBodyClosed(t *testing.T) SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
resp, err := s.Do(r)
if resp != nil && resp.Body != nil && resp.Body.(*mocks.Body).IsOpen() {
t.Fatal("autorest: Expected Body to be closed -- it was left open")
}
return resp, err
})
}
}
type mockAuthorizer struct{}
func (ma mockAuthorizer) WithAuthorization() PrepareDecorator {
return WithHeader(headerAuthorization, mocks.TestAuthorizationHeader)
}
type mockFailingAuthorizer struct{}
func (mfa mockFailingAuthorizer) WithAuthorization() PrepareDecorator {
return func(p Preparer) Preparer {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
return r, fmt.Errorf("ERROR: mockFailingAuthorizer returned expected error")
})
}
}
type mockInspector struct {
wasInvoked bool
}
func (mi *mockInspector) WithInspection() PrepareDecorator {
return func(p Preparer) Preparer {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
mi.wasInvoked = true
return p.Prepare(r)
})
}
}
func (mi *mockInspector) ByInspecting() RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
mi.wasInvoked = true
return r.Respond(resp)
})
}
}
func withMessage(output *string, msg string) SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
resp, err := s.Do(r)
if err == nil {
*output += msg
}
return resp, err
})
}
}
func withErrorRespondDecorator(e *error) RespondDecorator {
return func(r Responder) Responder {
return ResponderFunc(func(resp *http.Response) error {
err := r.Respond(resp)
if err != nil {
return err
}
*e = fmt.Errorf("autorest: Faux Respond Error")
return *e
})
}
}

View File

@@ -1,56 +0,0 @@
package utils
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"fmt"
"os"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/azure"
)
// GetAuthorizer gets an Azure Service Principal authorizer.
// This func assumes "AZURE_TENANT_ID", "AZURE_CLIENT_ID",
// "AZURE_CLIENT_SECRET" are set as environment variables.
func GetAuthorizer(env azure.Environment) (*autorest.BearerAuthorizer, error) {
tenantID := GetEnvVarOrExit("AZURE_TENANT_ID")
oauthConfig, err := adal.NewOAuthConfig(env.ActiveDirectoryEndpoint, tenantID)
if err != nil {
return nil, err
}
clientID := GetEnvVarOrExit("AZURE_CLIENT_ID")
clientSecret := GetEnvVarOrExit("AZURE_CLIENT_SECRET")
spToken, err := adal.NewServicePrincipalToken(*oauthConfig, clientID, clientSecret, env.ResourceManagerEndpoint)
if err != nil {
return nil, err
}
return autorest.NewBearerAuthorizer(spToken), nil
}
// GetEnvVarOrExit returns the value of specified environment variable or terminates if it's not defined.
func GetEnvVarOrExit(varName string) string {
value := os.Getenv(varName)
if value == "" {
fmt.Printf("Missing environment variable %s\n", varName)
os.Exit(1)
}
return value
}

View File

@@ -1,32 +0,0 @@
package utils
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"bytes"
"os/exec"
)
// GetCommit returns git HEAD (short)
func GetCommit() string {
cmd := exec.Command("git", "rev-parse", "HEAD")
var out bytes.Buffer
cmd.Stdout = &out
err := cmd.Run()
if err != nil {
return ""
}
return string(out.Bytes()[:7])
}

View File

@@ -1,395 +0,0 @@
/*
Package validation provides methods for validating parameter value using reflection.
*/
package validation
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"fmt"
"reflect"
"regexp"
"strings"
)
// Constraint stores constraint name, target field name
// Rule and chain validations.
type Constraint struct {
// Target field name for validation.
Target string
// Constraint name e.g. minLength, MaxLength, Pattern, etc.
Name string
// Rule for constraint e.g. greater than 10, less than 5 etc.
Rule interface{}
// Chain Validations for struct type
Chain []Constraint
}
// Validation stores parameter-wise validation.
type Validation struct {
TargetValue interface{}
Constraints []Constraint
}
// Constraint list
const (
Empty = "Empty"
Null = "Null"
ReadOnly = "ReadOnly"
Pattern = "Pattern"
MaxLength = "MaxLength"
MinLength = "MinLength"
MaxItems = "MaxItems"
MinItems = "MinItems"
MultipleOf = "MultipleOf"
UniqueItems = "UniqueItems"
InclusiveMaximum = "InclusiveMaximum"
ExclusiveMaximum = "ExclusiveMaximum"
ExclusiveMinimum = "ExclusiveMinimum"
InclusiveMinimum = "InclusiveMinimum"
)
// Validate method validates constraints on parameter
// passed in validation array.
func Validate(m []Validation) error {
for _, item := range m {
v := reflect.ValueOf(item.TargetValue)
for _, constraint := range item.Constraints {
var err error
switch v.Kind() {
case reflect.Ptr:
err = validatePtr(v, constraint)
case reflect.String:
err = validateString(v, constraint)
case reflect.Struct:
err = validateStruct(v, constraint)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
err = validateInt(v, constraint)
case reflect.Float32, reflect.Float64:
err = validateFloat(v, constraint)
case reflect.Array, reflect.Slice, reflect.Map:
err = validateArrayMap(v, constraint)
default:
err = createError(v, constraint, fmt.Sprintf("unknown type %v", v.Kind()))
}
if err != nil {
return err
}
}
}
return nil
}
func validateStruct(x reflect.Value, v Constraint, name ...string) error {
//Get field name from target name which is in format a.b.c
s := strings.Split(v.Target, ".")
f := x.FieldByName(s[len(s)-1])
if isZero(f) {
return createError(x, v, fmt.Sprintf("field %q doesn't exist", v.Target))
}
return Validate([]Validation{
{
TargetValue: getInterfaceValue(f),
Constraints: []Constraint{v},
},
})
}
func validatePtr(x reflect.Value, v Constraint) error {
if v.Name == ReadOnly {
if !x.IsNil() {
return createError(x.Elem(), v, "readonly parameter; must send as nil or empty in request")
}
return nil
}
if x.IsNil() {
return checkNil(x, v)
}
if v.Chain != nil {
return Validate([]Validation{
{
TargetValue: getInterfaceValue(x.Elem()),
Constraints: v.Chain,
},
})
}
return nil
}
func validateInt(x reflect.Value, v Constraint) error {
i := x.Int()
r, ok := v.Rule.(int)
if !ok {
return createError(x, v, fmt.Sprintf("rule must be integer value for %v constraint; got: %v", v.Name, v.Rule))
}
switch v.Name {
case MultipleOf:
if i%int64(r) != 0 {
return createError(x, v, fmt.Sprintf("value must be a multiple of %v", r))
}
case ExclusiveMinimum:
if i <= int64(r) {
return createError(x, v, fmt.Sprintf("value must be greater than %v", r))
}
case ExclusiveMaximum:
if i >= int64(r) {
return createError(x, v, fmt.Sprintf("value must be less than %v", r))
}
case InclusiveMinimum:
if i < int64(r) {
return createError(x, v, fmt.Sprintf("value must be greater than or equal to %v", r))
}
case InclusiveMaximum:
if i > int64(r) {
return createError(x, v, fmt.Sprintf("value must be less than or equal to %v", r))
}
default:
return createError(x, v, fmt.Sprintf("constraint %v is not applicable for type integer", v.Name))
}
return nil
}
func validateFloat(x reflect.Value, v Constraint) error {
f := x.Float()
r, ok := v.Rule.(float64)
if !ok {
return createError(x, v, fmt.Sprintf("rule must be float value for %v constraint; got: %v", v.Name, v.Rule))
}
switch v.Name {
case ExclusiveMinimum:
if f <= r {
return createError(x, v, fmt.Sprintf("value must be greater than %v", r))
}
case ExclusiveMaximum:
if f >= r {
return createError(x, v, fmt.Sprintf("value must be less than %v", r))
}
case InclusiveMinimum:
if f < r {
return createError(x, v, fmt.Sprintf("value must be greater than or equal to %v", r))
}
case InclusiveMaximum:
if f > r {
return createError(x, v, fmt.Sprintf("value must be less than or equal to %v", r))
}
default:
return createError(x, v, fmt.Sprintf("constraint %s is not applicable for type float", v.Name))
}
return nil
}
func validateString(x reflect.Value, v Constraint) error {
s := x.String()
switch v.Name {
case Empty:
if len(s) == 0 {
return checkEmpty(x, v)
}
case Pattern:
reg, err := regexp.Compile(v.Rule.(string))
if err != nil {
return createError(x, v, err.Error())
}
if !reg.MatchString(s) {
return createError(x, v, fmt.Sprintf("value doesn't match pattern %v", v.Rule))
}
case MaxLength:
if _, ok := v.Rule.(int); !ok {
return createError(x, v, fmt.Sprintf("rule must be integer value for %v constraint; got: %v", v.Name, v.Rule))
}
if len(s) > v.Rule.(int) {
return createError(x, v, fmt.Sprintf("value length must be less than or equal to %v", v.Rule))
}
case MinLength:
if _, ok := v.Rule.(int); !ok {
return createError(x, v, fmt.Sprintf("rule must be integer value for %v constraint; got: %v", v.Name, v.Rule))
}
if len(s) < v.Rule.(int) {
return createError(x, v, fmt.Sprintf("value length must be greater than or equal to %v", v.Rule))
}
case ReadOnly:
if len(s) > 0 {
return createError(reflect.ValueOf(s), v, "readonly parameter; must send as nil or empty in request")
}
default:
return createError(x, v, fmt.Sprintf("constraint %s is not applicable to string type", v.Name))
}
if v.Chain != nil {
return Validate([]Validation{
{
TargetValue: getInterfaceValue(x),
Constraints: v.Chain,
},
})
}
return nil
}
func validateArrayMap(x reflect.Value, v Constraint) error {
switch v.Name {
case Null:
if x.IsNil() {
return checkNil(x, v)
}
case Empty:
if x.IsNil() || x.Len() == 0 {
return checkEmpty(x, v)
}
case MaxItems:
if _, ok := v.Rule.(int); !ok {
return createError(x, v, fmt.Sprintf("rule must be integer for %v constraint; got: %v", v.Name, v.Rule))
}
if x.Len() > v.Rule.(int) {
return createError(x, v, fmt.Sprintf("maximum item limit is %v; got: %v", v.Rule, x.Len()))
}
case MinItems:
if _, ok := v.Rule.(int); !ok {
return createError(x, v, fmt.Sprintf("rule must be integer for %v constraint; got: %v", v.Name, v.Rule))
}
if x.Len() < v.Rule.(int) {
return createError(x, v, fmt.Sprintf("minimum item limit is %v; got: %v", v.Rule, x.Len()))
}
case UniqueItems:
if x.Kind() == reflect.Array || x.Kind() == reflect.Slice {
if !checkForUniqueInArray(x) {
return createError(x, v, fmt.Sprintf("all items in parameter %q must be unique; got:%v", v.Target, x))
}
} else if x.Kind() == reflect.Map {
if !checkForUniqueInMap(x) {
return createError(x, v, fmt.Sprintf("all items in parameter %q must be unique; got:%v", v.Target, x))
}
} else {
return createError(x, v, fmt.Sprintf("type must be array, slice or map for constraint %v; got: %v", v.Name, x.Kind()))
}
case ReadOnly:
if x.Len() != 0 {
return createError(x, v, "readonly parameter; must send as nil or empty in request")
}
case Pattern:
reg, err := regexp.Compile(v.Rule.(string))
if err != nil {
return createError(x, v, err.Error())
}
keys := x.MapKeys()
for _, k := range keys {
if !reg.MatchString(k.String()) {
return createError(k, v, fmt.Sprintf("map key doesn't match pattern %v", v.Rule))
}
}
default:
return createError(x, v, fmt.Sprintf("constraint %v is not applicable to array, slice and map type", v.Name))
}
if v.Chain != nil {
return Validate([]Validation{
{
TargetValue: getInterfaceValue(x),
Constraints: v.Chain,
},
})
}
return nil
}
func checkNil(x reflect.Value, v Constraint) error {
if _, ok := v.Rule.(bool); !ok {
return createError(x, v, fmt.Sprintf("rule must be bool value for %v constraint; got: %v", v.Name, v.Rule))
}
if v.Rule.(bool) {
return createError(x, v, "value can not be null; required parameter")
}
return nil
}
func checkEmpty(x reflect.Value, v Constraint) error {
if _, ok := v.Rule.(bool); !ok {
return createError(x, v, fmt.Sprintf("rule must be bool value for %v constraint; got: %v", v.Name, v.Rule))
}
if v.Rule.(bool) {
return createError(x, v, "value can not be null or empty; required parameter")
}
return nil
}
func checkForUniqueInArray(x reflect.Value) bool {
if x == reflect.Zero(reflect.TypeOf(x)) || x.Len() == 0 {
return false
}
arrOfInterface := make([]interface{}, x.Len())
for i := 0; i < x.Len(); i++ {
arrOfInterface[i] = x.Index(i).Interface()
}
m := make(map[interface{}]bool)
for _, val := range arrOfInterface {
if m[val] {
return false
}
m[val] = true
}
return true
}
func checkForUniqueInMap(x reflect.Value) bool {
if x == reflect.Zero(reflect.TypeOf(x)) || x.Len() == 0 {
return false
}
mapOfInterface := make(map[interface{}]interface{}, x.Len())
keys := x.MapKeys()
for _, k := range keys {
mapOfInterface[k.Interface()] = x.MapIndex(k).Interface()
}
m := make(map[interface{}]bool)
for _, val := range mapOfInterface {
if m[val] {
return false
}
m[val] = true
}
return true
}
func getInterfaceValue(x reflect.Value) interface{} {
if x.Kind() == reflect.Invalid {
return nil
}
return x.Interface()
}
func isZero(x interface{}) bool {
return x == reflect.Zero(reflect.TypeOf(x)).Interface()
}
func createError(x reflect.Value, v Constraint, err string) error {
return fmt.Errorf("autorest/validation: validation failed: parameter=%s constraint=%s value=%#v details: %s",
v.Target, v.Name, getInterfaceValue(x), err)
}
// NewErrorWithValidationError appends package type and method name in
// validation error.
func NewErrorWithValidationError(err error, packageType, method string) error {
return fmt.Errorf("%s#%s: Invalid input: %v", packageType, method, err)
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,49 +0,0 @@
package autorest
// Copyright 2017 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import (
"bytes"
"fmt"
"strings"
"sync"
)
const (
major = 8
minor = 0
patch = 0
tag = ""
)
var once sync.Once
var version string
// Version returns the semantic version (see http://semver.org).
func Version() string {
once.Do(func() {
semver := fmt.Sprintf("%d.%d.%d", major, minor, patch)
verBuilder := bytes.NewBufferString(semver)
if tag != "" && tag != "-" {
updated := strings.TrimPrefix(tag, "-")
_, err := verBuilder.WriteString("-" + updated)
if err == nil {
verBuilder = bytes.NewBufferString(semver)
}
}
version = verBuilder.String()
})
return version
}

View File

@@ -1,44 +0,0 @@
hash: 6e0121d946623e7e609280b1b18627e1c8a767fdece54cb97c4447c1167cbc46
updated: 2017-08-31T13:58:01.034822883+01:00
imports:
- name: github.com/dgrijalva/jwt-go
version: 2268707a8f0843315e2004ee4f1d021dc08baedf
subpackages:
- .
- name: github.com/dimchansky/utfbom
version: 6c6132ff69f0f6c088739067407b5d32c52e1d0f
- name: github.com/mitchellh/go-homedir
version: b8bc1bf767474819792c23f32d8286a45736f1c6
- name: golang.org/x/crypto
version: 81e90905daefcd6fd217b62423c0908922eadb30
repo: https://github.com/golang/crypto.git
vcs: git
subpackages:
- pkcs12
- pkcs12/internal/rc2
- name: golang.org/x/net
version: 66aacef3dd8a676686c7ae3716979581e8b03c47
repo: https://github.com/golang/net.git
vcs: git
subpackages:
- .
- name: golang.org/x/text
version: 21e35d45962262c8ee80f6cb048dcf95ad0e9d79
repo: https://github.com/golang/text.git
vcs: git
subpackages:
- .
testImports:
- name: github.com/davecgh/go-spew
version: 04cdfd42973bb9c8589fd6a731800cf222fde1a9
subpackages:
- spew
- name: github.com/pmezard/go-difflib
version: d8ed2627bdf02c080bf22230dbb337003b7aba2d
subpackages:
- difflib
- name: github.com/stretchr/testify
version: 890a5c3458b43e6104ff5da8dfa139d013d77544
subpackages:
- assert
- require

View File

@@ -1,22 +0,0 @@
package: github.com/Azure/go-autorest
import:
- package: github.com/dgrijalva/jwt-go
subpackages:
- .
- package: golang.org/x/crypto
repo: https://github.com/golang/crypto.git
vcs: git
subpackages:
- pkcs12
- package: golang.org/x/net
repo: https://github.com/golang/net.git
vcs: git
subpackages:
- .
- package: golang.org/x/text
repo: https://github.com/golang/text.git
vcs: git
subpackages:
- .
- package: github.com/mitchellh/go-homedir
- package: github.com/dimchansky/utfbom

Binary file not shown.

Binary file not shown.

View File

@@ -1,12 +0,0 @@
{
"clientId": "client-id-123",
"clientSecret": "client-secret-456",
"subscriptionId": "sub-id-789",
"tenantId": "tenant-id-123",
"activeDirectoryEndpointUrl": "https://login.microsoftonline.com",
"resourceManagerEndpointUrl": "https://management.azure.com/",
"activeDirectoryGraphResourceId": "https://graph.windows.net/",
"sqlManagementEndpointUrl": "https://management.core.windows.net:8443/",
"galleryEndpointUrl": "https://gallery.azure.com/",
"managementEndpointUrl": "https://management.core.windows.net/"
}

View File

@@ -1,5 +0,0 @@
TAGS
tags
.*.swp
tomlcheck/tomlcheck
toml.test

View File

@@ -1,15 +0,0 @@
language: go
go:
- 1.1
- 1.2
- 1.3
- 1.4
- 1.5
- 1.6
- tip
install:
- go install ./...
- go get github.com/BurntSushi/toml-test
script:
- export PATH="$PATH:$HOME/gopath/bin"
- make test

View File

@@ -1,3 +0,0 @@
Compatible with TOML version
[v0.4.0](https://github.com/toml-lang/toml/blob/v0.4.0/versions/en/toml-v0.4.0.md)

View File

@@ -1,19 +0,0 @@
install:
go install ./...
test: install
go test -v
toml-test toml-test-decoder
toml-test -encoder toml-test-encoder
fmt:
gofmt -w *.go */*.go
colcheck *.go */*.go
tags:
find ./ -name '*.go' -print0 | xargs -0 gotags > TAGS
push:
git push origin master
git push github master

View File

@@ -1,218 +0,0 @@
## TOML parser and encoder for Go with reflection
TOML stands for Tom's Obvious, Minimal Language. This Go package provides a
reflection interface similar to Go's standard library `json` and `xml`
packages. This package also supports the `encoding.TextUnmarshaler` and
`encoding.TextMarshaler` interfaces so that you can define custom data
representations. (There is an example of this below.)
Spec: https://github.com/toml-lang/toml
Compatible with TOML version
[v0.4.0](https://github.com/toml-lang/toml/blob/master/versions/en/toml-v0.4.0.md)
Documentation: https://godoc.org/github.com/BurntSushi/toml
Installation:
```bash
go get github.com/BurntSushi/toml
```
Try the toml validator:
```bash
go get github.com/BurntSushi/toml/cmd/tomlv
tomlv some-toml-file.toml
```
[![Build Status](https://travis-ci.org/BurntSushi/toml.svg?branch=master)](https://travis-ci.org/BurntSushi/toml) [![GoDoc](https://godoc.org/github.com/BurntSushi/toml?status.svg)](https://godoc.org/github.com/BurntSushi/toml)
### Testing
This package passes all tests in
[toml-test](https://github.com/BurntSushi/toml-test) for both the decoder
and the encoder.
### Examples
This package works similarly to how the Go standard library handles `XML`
and `JSON`. Namely, data is loaded into Go values via reflection.
For the simplest example, consider some TOML file as just a list of keys
and values:
```toml
Age = 25
Cats = [ "Cauchy", "Plato" ]
Pi = 3.14
Perfection = [ 6, 28, 496, 8128 ]
DOB = 1987-07-05T05:45:00Z
```
Which could be defined in Go as:
```go
type Config struct {
Age int
Cats []string
Pi float64
Perfection []int
DOB time.Time // requires `import time`
}
```
And then decoded with:
```go
var conf Config
if _, err := toml.Decode(tomlData, &conf); err != nil {
// handle error
}
```
You can also use struct tags if your struct field name doesn't map to a TOML
key value directly:
```toml
some_key_NAME = "wat"
```
```go
type TOML struct {
ObscureKey string `toml:"some_key_NAME"`
}
```
### Using the `encoding.TextUnmarshaler` interface
Here's an example that automatically parses duration strings into
`time.Duration` values:
```toml
[[song]]
name = "Thunder Road"
duration = "4m49s"
[[song]]
name = "Stairway to Heaven"
duration = "8m03s"
```
Which can be decoded with:
```go
type song struct {
Name string
Duration duration
}
type songs struct {
Song []song
}
var favorites songs
if _, err := toml.Decode(blob, &favorites); err != nil {
log.Fatal(err)
}
for _, s := range favorites.Song {
fmt.Printf("%s (%s)\n", s.Name, s.Duration)
}
```
And you'll also need a `duration` type that satisfies the
`encoding.TextUnmarshaler` interface:
```go
type duration struct {
time.Duration
}
func (d *duration) UnmarshalText(text []byte) error {
var err error
d.Duration, err = time.ParseDuration(string(text))
return err
}
```
### More complex usage
Here's an example of how to load the example from the official spec page:
```toml
# This is a TOML document. Boom.
title = "TOML Example"
[owner]
name = "Tom Preston-Werner"
organization = "GitHub"
bio = "GitHub Cofounder & CEO\nLikes tater tots and beer."
dob = 1979-05-27T07:32:00Z # First class dates? Why not?
[database]
server = "192.168.1.1"
ports = [ 8001, 8001, 8002 ]
connection_max = 5000
enabled = true
[servers]
# You can indent as you please. Tabs or spaces. TOML don't care.
[servers.alpha]
ip = "10.0.0.1"
dc = "eqdc10"
[servers.beta]
ip = "10.0.0.2"
dc = "eqdc10"
[clients]
data = [ ["gamma", "delta"], [1, 2] ] # just an update to make sure parsers support it
# Line breaks are OK when inside arrays
hosts = [
"alpha",
"omega"
]
```
And the corresponding Go types are:
```go
type tomlConfig struct {
Title string
Owner ownerInfo
DB database `toml:"database"`
Servers map[string]server
Clients clients
}
type ownerInfo struct {
Name string
Org string `toml:"organization"`
Bio string
DOB time.Time
}
type database struct {
Server string
Ports []int
ConnMax int `toml:"connection_max"`
Enabled bool
}
type server struct {
IP string
DC string
}
type clients struct {
Data [][]interface{}
Hosts []string
}
```
Note that a case insensitive match will be tried if an exact match can't be
found.
A working example of the above can be found in `_examples/example.{go,toml}`.

View File

@@ -1,61 +0,0 @@
package main
import (
"fmt"
"time"
"github.com/BurntSushi/toml"
)
type tomlConfig struct {
Title string
Owner ownerInfo
DB database `toml:"database"`
Servers map[string]server
Clients clients
}
type ownerInfo struct {
Name string
Org string `toml:"organization"`
Bio string
DOB time.Time
}
type database struct {
Server string
Ports []int
ConnMax int `toml:"connection_max"`
Enabled bool
}
type server struct {
IP string
DC string
}
type clients struct {
Data [][]interface{}
Hosts []string
}
func main() {
var config tomlConfig
if _, err := toml.DecodeFile("example.toml", &config); err != nil {
fmt.Println(err)
return
}
fmt.Printf("Title: %s\n", config.Title)
fmt.Printf("Owner: %s (%s, %s), Born: %s\n",
config.Owner.Name, config.Owner.Org, config.Owner.Bio,
config.Owner.DOB)
fmt.Printf("Database: %s %v (Max conn. %d), Enabled? %v\n",
config.DB.Server, config.DB.Ports, config.DB.ConnMax,
config.DB.Enabled)
for serverName, server := range config.Servers {
fmt.Printf("Server: %s (%s, %s)\n", serverName, server.IP, server.DC)
}
fmt.Printf("Client data: %v\n", config.Clients.Data)
fmt.Printf("Client hosts: %v\n", config.Clients.Hosts)
}

View File

@@ -1,35 +0,0 @@
# This is a TOML document. Boom.
title = "TOML Example"
[owner]
name = "Tom Preston-Werner"
organization = "GitHub"
bio = "GitHub Cofounder & CEO\nLikes tater tots and beer."
dob = 1979-05-27T07:32:00Z # First class dates? Why not?
[database]
server = "192.168.1.1"
ports = [ 8001, 8001, 8002 ]
connection_max = 5000
enabled = true
[servers]
# You can indent as you please. Tabs or spaces. TOML don't care.
[servers.alpha]
ip = "10.0.0.1"
dc = "eqdc10"
[servers.beta]
ip = "10.0.0.2"
dc = "eqdc10"
[clients]
data = [ ["gamma", "delta"], [1, 2] ] # just an update to make sure parsers support it
# Line breaks are OK when inside arrays
hosts = [
"alpha",
"omega"
]

View File

@@ -1,22 +0,0 @@
# Test file for TOML
# Only this one tries to emulate a TOML file written by a user of the kind of parser writers probably hate
# This part you'll really hate
[the]
test_string = "You'll hate me after this - #" # " Annoying, isn't it?
[the.hard]
test_array = [ "] ", " # "] # ] There you go, parse this!
test_array2 = [ "Test #11 ]proved that", "Experiment #9 was a success" ]
# You didn't think it'd as easy as chucking out the last #, did you?
another_test_string = " Same thing, but with a string #"
harder_test_string = " And when \"'s are in the string, along with # \"" # "and comments are there too"
# Things will get harder
[the.hard.bit#]
what? = "You don't think some user won't do that?"
multi_line_array = [
"]",
# ] Oh yes I did
]

View File

@@ -1,4 +0,0 @@
# [x] you
# [x.y] don't
# [x.y.z] need these
[x.y.z.w] # for this to work

View File

@@ -1,6 +0,0 @@
# DO NOT WANT
[fruit]
type = "apple"
[fruit.type]
apple = "yes"

View File

@@ -1,35 +0,0 @@
# This is an INVALID TOML document. Boom.
# Can you spot the error without help?
title = "TOML Example"
[owner]
name = "Tom Preston-Werner"
organization = "GitHub"
bio = "GitHub Cofounder & CEO\nLikes tater tots and beer."
dob = 1979-05-27T7:32:00Z # First class dates? Why not?
[database]
server = "192.168.1.1"
ports = [ 8001, 8001, 8002 ]
connection_max = 5000
enabled = true
[servers]
# You can indent as you please. Tabs or spaces. TOML don't care.
[servers.alpha]
ip = "10.0.0.1"
dc = "eqdc10"
[servers.beta]
ip = "10.0.0.2"
dc = "eqdc10"
[clients]
data = [ ["gamma", "delta"], [1, 2] ] # just an update to make sure parsers support it
# Line breaks are OK when inside arrays
hosts = [
"alpha",
"omega"
]

View File

@@ -1,5 +0,0 @@
Age = 25
Cats = [ "Cauchy", "Plato" ]
Pi = 3.14
Perfection = [ 6, 28, 496, 8128 ]
DOB = 1987-07-05T05:45:00Z

View File

@@ -1 +0,0 @@
some_key_NAME = "wat"

View File

@@ -1,13 +0,0 @@
# Implements the TOML test suite interface
This is an implementation of the interface expected by
[toml-test](https://github.com/BurntSushi/toml-test) for my
[toml parser written in Go](https://github.com/BurntSushi/toml).
In particular, it maps TOML data on `stdin` to a JSON format on `stdout`.
Compatible with TOML version
[v0.4.0](https://github.com/toml-lang/toml/blob/master/versions/en/toml-v0.4.0.md)
Compatible with `toml-test` version
[v0.2.0](https://github.com/BurntSushi/toml-test/tree/v0.2.0)

View File

@@ -1,90 +0,0 @@
// Command toml-test-decoder satisfies the toml-test interface for testing
// TOML decoders. Namely, it accepts TOML on stdin and outputs JSON on stdout.
package main
import (
"encoding/json"
"flag"
"fmt"
"log"
"os"
"path"
"time"
"github.com/BurntSushi/toml"
)
func init() {
log.SetFlags(0)
flag.Usage = usage
flag.Parse()
}
func usage() {
log.Printf("Usage: %s < toml-file\n", path.Base(os.Args[0]))
flag.PrintDefaults()
os.Exit(1)
}
func main() {
if flag.NArg() != 0 {
flag.Usage()
}
var tmp interface{}
if _, err := toml.DecodeReader(os.Stdin, &tmp); err != nil {
log.Fatalf("Error decoding TOML: %s", err)
}
typedTmp := translate(tmp)
if err := json.NewEncoder(os.Stdout).Encode(typedTmp); err != nil {
log.Fatalf("Error encoding JSON: %s", err)
}
}
func translate(tomlData interface{}) interface{} {
switch orig := tomlData.(type) {
case map[string]interface{}:
typed := make(map[string]interface{}, len(orig))
for k, v := range orig {
typed[k] = translate(v)
}
return typed
case []map[string]interface{}:
typed := make([]map[string]interface{}, len(orig))
for i, v := range orig {
typed[i] = translate(v).(map[string]interface{})
}
return typed
case []interface{}:
typed := make([]interface{}, len(orig))
for i, v := range orig {
typed[i] = translate(v)
}
// We don't really need to tag arrays, but let's be future proof.
// (If TOML ever supports tuples, we'll need this.)
return tag("array", typed)
case time.Time:
return tag("datetime", orig.Format("2006-01-02T15:04:05Z"))
case bool:
return tag("bool", fmt.Sprintf("%v", orig))
case int64:
return tag("integer", fmt.Sprintf("%d", orig))
case float64:
return tag("float", fmt.Sprintf("%v", orig))
case string:
return tag("string", orig)
}
panic(fmt.Sprintf("Unknown type: %T", tomlData))
}
func tag(typeName string, data interface{}) map[string]interface{} {
return map[string]interface{}{
"type": typeName,
"value": data,
}
}

View File

@@ -1,13 +0,0 @@
# Implements the TOML test suite interface for TOML encoders
This is an implementation of the interface expected by
[toml-test](https://github.com/BurntSushi/toml-test) for the
[TOML encoder](https://github.com/BurntSushi/toml).
In particular, it maps JSON data on `stdin` to a TOML format on `stdout`.
Compatible with TOML version
[v0.4.0](https://github.com/toml-lang/toml/blob/master/versions/en/toml-v0.4.0.md)
Compatible with `toml-test` version
[v0.2.0](https://github.com/BurntSushi/toml-test/tree/v0.2.0)

View File

@@ -1,131 +0,0 @@
// Command toml-test-encoder satisfies the toml-test interface for testing
// TOML encoders. Namely, it accepts JSON on stdin and outputs TOML on stdout.
package main
import (
"encoding/json"
"flag"
"log"
"os"
"path"
"strconv"
"time"
"github.com/BurntSushi/toml"
)
func init() {
log.SetFlags(0)
flag.Usage = usage
flag.Parse()
}
func usage() {
log.Printf("Usage: %s < json-file\n", path.Base(os.Args[0]))
flag.PrintDefaults()
os.Exit(1)
}
func main() {
if flag.NArg() != 0 {
flag.Usage()
}
var tmp interface{}
if err := json.NewDecoder(os.Stdin).Decode(&tmp); err != nil {
log.Fatalf("Error decoding JSON: %s", err)
}
tomlData := translate(tmp)
if err := toml.NewEncoder(os.Stdout).Encode(tomlData); err != nil {
log.Fatalf("Error encoding TOML: %s", err)
}
}
func translate(typedJson interface{}) interface{} {
switch v := typedJson.(type) {
case map[string]interface{}:
if len(v) == 2 && in("type", v) && in("value", v) {
return untag(v)
}
m := make(map[string]interface{}, len(v))
for k, v2 := range v {
m[k] = translate(v2)
}
return m
case []interface{}:
tabArray := make([]map[string]interface{}, len(v))
for i := range v {
if m, ok := translate(v[i]).(map[string]interface{}); ok {
tabArray[i] = m
} else {
log.Fatalf("JSON arrays may only contain objects. This " +
"corresponds to only tables being allowed in " +
"TOML table arrays.")
}
}
return tabArray
}
log.Fatalf("Unrecognized JSON format '%T'.", typedJson)
panic("unreachable")
}
func untag(typed map[string]interface{}) interface{} {
t := typed["type"].(string)
v := typed["value"]
switch t {
case "string":
return v.(string)
case "integer":
v := v.(string)
n, err := strconv.Atoi(v)
if err != nil {
log.Fatalf("Could not parse '%s' as integer: %s", v, err)
}
return n
case "float":
v := v.(string)
f, err := strconv.ParseFloat(v, 64)
if err != nil {
log.Fatalf("Could not parse '%s' as float64: %s", v, err)
}
return f
case "datetime":
v := v.(string)
t, err := time.Parse("2006-01-02T15:04:05Z", v)
if err != nil {
log.Fatalf("Could not parse '%s' as a datetime: %s", v, err)
}
return t
case "bool":
v := v.(string)
switch v {
case "true":
return true
case "false":
return false
}
log.Fatalf("Could not parse '%s' as a boolean.", v)
case "array":
v := v.([]interface{})
array := make([]interface{}, len(v))
for i := range v {
if m, ok := v[i].(map[string]interface{}); ok {
array[i] = untag(m)
} else {
log.Fatalf("Arrays may only contain other arrays or "+
"primitive values, but found a '%T'.", m)
}
}
return array
}
log.Fatalf("Unrecognized tag type '%s'.", t)
panic("unreachable")
}
func in(key string, m map[string]interface{}) bool {
_, ok := m[key]
return ok
}

View File

@@ -1,21 +0,0 @@
# TOML Validator
If Go is installed, it's simple to try it out:
```bash
go get github.com/BurntSushi/toml/cmd/tomlv
tomlv some-toml-file.toml
```
You can see the types of every key in a TOML file with:
```bash
tomlv -types some-toml-file.toml
```
At the moment, only one error message is reported at a time. Error messages
include line numbers. No output means that the files given are valid TOML, or
there is a bug in `tomlv`.
Compatible with TOML version
[v0.4.0](https://github.com/toml-lang/toml/blob/master/versions/en/toml-v0.4.0.md)

View File

@@ -1,61 +0,0 @@
// Command tomlv validates TOML documents and prints each key's type.
package main
import (
"flag"
"fmt"
"log"
"os"
"path"
"strings"
"text/tabwriter"
"github.com/BurntSushi/toml"
)
var (
flagTypes = false
)
func init() {
log.SetFlags(0)
flag.BoolVar(&flagTypes, "types", flagTypes,
"When set, the types of every defined key will be shown.")
flag.Usage = usage
flag.Parse()
}
func usage() {
log.Printf("Usage: %s toml-file [ toml-file ... ]\n",
path.Base(os.Args[0]))
flag.PrintDefaults()
os.Exit(1)
}
func main() {
if flag.NArg() < 1 {
flag.Usage()
}
for _, f := range flag.Args() {
var tmp interface{}
md, err := toml.DecodeFile(f, &tmp)
if err != nil {
log.Fatalf("Error in '%s': %s", f, err)
}
if flagTypes {
printTypes(md)
}
}
}
func printTypes(md toml.MetaData) {
tabw := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
for _, key := range md.Keys() {
fmt.Fprintf(tabw, "%s%s\t%s\n",
strings.Repeat(" ", len(key)-1), key, md.Type(key...))
}
tabw.Flush()
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,615 +0,0 @@
package toml
import (
"bytes"
"fmt"
"log"
"net"
"testing"
"time"
)
func TestEncodeRoundTrip(t *testing.T) {
type Config struct {
Age int
Cats []string
Pi float64
Perfection []int
DOB time.Time
Ipaddress net.IP
}
var inputs = Config{
13,
[]string{"one", "two", "three"},
3.145,
[]int{11, 2, 3, 4},
time.Now(),
net.ParseIP("192.168.59.254"),
}
var firstBuffer bytes.Buffer
e := NewEncoder(&firstBuffer)
err := e.Encode(inputs)
if err != nil {
t.Fatal(err)
}
var outputs Config
if _, err := Decode(firstBuffer.String(), &outputs); err != nil {
t.Logf("Could not decode:\n-----\n%s\n-----\n",
firstBuffer.String())
t.Fatal(err)
}
// could test each value individually, but I'm lazy
var secondBuffer bytes.Buffer
e2 := NewEncoder(&secondBuffer)
err = e2.Encode(outputs)
if err != nil {
t.Fatal(err)
}
if firstBuffer.String() != secondBuffer.String() {
t.Error(
firstBuffer.String(),
"\n\n is not identical to\n\n",
secondBuffer.String())
}
}
// XXX(burntsushi)
// I think these tests probably should be removed. They are good, but they
// ought to be obsolete by toml-test.
func TestEncode(t *testing.T) {
type Embedded struct {
Int int `toml:"_int"`
}
type NonStruct int
date := time.Date(2014, 5, 11, 20, 30, 40, 0, time.FixedZone("IST", 3600))
dateStr := "2014-05-11T19:30:40Z"
tests := map[string]struct {
input interface{}
wantOutput string
wantError error
}{
"bool field": {
input: struct {
BoolTrue bool
BoolFalse bool
}{true, false},
wantOutput: "BoolTrue = true\nBoolFalse = false\n",
},
"int fields": {
input: struct {
Int int
Int8 int8
Int16 int16
Int32 int32
Int64 int64
}{1, 2, 3, 4, 5},
wantOutput: "Int = 1\nInt8 = 2\nInt16 = 3\nInt32 = 4\nInt64 = 5\n",
},
"uint fields": {
input: struct {
Uint uint
Uint8 uint8
Uint16 uint16
Uint32 uint32
Uint64 uint64
}{1, 2, 3, 4, 5},
wantOutput: "Uint = 1\nUint8 = 2\nUint16 = 3\nUint32 = 4" +
"\nUint64 = 5\n",
},
"float fields": {
input: struct {
Float32 float32
Float64 float64
}{1.5, 2.5},
wantOutput: "Float32 = 1.5\nFloat64 = 2.5\n",
},
"string field": {
input: struct{ String string }{"foo"},
wantOutput: "String = \"foo\"\n",
},
"string field and unexported field": {
input: struct {
String string
unexported int
}{"foo", 0},
wantOutput: "String = \"foo\"\n",
},
"datetime field in UTC": {
input: struct{ Date time.Time }{date},
wantOutput: fmt.Sprintf("Date = %s\n", dateStr),
},
"datetime field as primitive": {
// Using a map here to fail if isStructOrMap() returns true for
// time.Time.
input: map[string]interface{}{
"Date": date,
"Int": 1,
},
wantOutput: fmt.Sprintf("Date = %s\nInt = 1\n", dateStr),
},
"array fields": {
input: struct {
IntArray0 [0]int
IntArray3 [3]int
}{[0]int{}, [3]int{1, 2, 3}},
wantOutput: "IntArray0 = []\nIntArray3 = [1, 2, 3]\n",
},
"slice fields": {
input: struct{ IntSliceNil, IntSlice0, IntSlice3 []int }{
nil, []int{}, []int{1, 2, 3},
},
wantOutput: "IntSlice0 = []\nIntSlice3 = [1, 2, 3]\n",
},
"datetime slices": {
input: struct{ DatetimeSlice []time.Time }{
[]time.Time{date, date},
},
wantOutput: fmt.Sprintf("DatetimeSlice = [%s, %s]\n",
dateStr, dateStr),
},
"nested arrays and slices": {
input: struct {
SliceOfArrays [][2]int
ArrayOfSlices [2][]int
SliceOfArraysOfSlices [][2][]int
ArrayOfSlicesOfArrays [2][][2]int
SliceOfMixedArrays [][2]interface{}
ArrayOfMixedSlices [2][]interface{}
}{
[][2]int{{1, 2}, {3, 4}},
[2][]int{{1, 2}, {3, 4}},
[][2][]int{
{
{1, 2}, {3, 4},
},
{
{5, 6}, {7, 8},
},
},
[2][][2]int{
{
{1, 2}, {3, 4},
},
{
{5, 6}, {7, 8},
},
},
[][2]interface{}{
{1, 2}, {"a", "b"},
},
[2][]interface{}{
{1, 2}, {"a", "b"},
},
},
wantOutput: `SliceOfArrays = [[1, 2], [3, 4]]
ArrayOfSlices = [[1, 2], [3, 4]]
SliceOfArraysOfSlices = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
ArrayOfSlicesOfArrays = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
SliceOfMixedArrays = [[1, 2], ["a", "b"]]
ArrayOfMixedSlices = [[1, 2], ["a", "b"]]
`,
},
"empty slice": {
input: struct{ Empty []interface{} }{[]interface{}{}},
wantOutput: "Empty = []\n",
},
"(error) slice with element type mismatch (string and integer)": {
input: struct{ Mixed []interface{} }{[]interface{}{1, "a"}},
wantError: errArrayMixedElementTypes,
},
"(error) slice with element type mismatch (integer and float)": {
input: struct{ Mixed []interface{} }{[]interface{}{1, 2.5}},
wantError: errArrayMixedElementTypes,
},
"slice with elems of differing Go types, same TOML types": {
input: struct {
MixedInts []interface{}
MixedFloats []interface{}
}{
[]interface{}{
int(1), int8(2), int16(3), int32(4), int64(5),
uint(1), uint8(2), uint16(3), uint32(4), uint64(5),
},
[]interface{}{float32(1.5), float64(2.5)},
},
wantOutput: "MixedInts = [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]\n" +
"MixedFloats = [1.5, 2.5]\n",
},
"(error) slice w/ element type mismatch (one is nested array)": {
input: struct{ Mixed []interface{} }{
[]interface{}{1, []interface{}{2}},
},
wantError: errArrayMixedElementTypes,
},
"(error) slice with 1 nil element": {
input: struct{ NilElement1 []interface{} }{[]interface{}{nil}},
wantError: errArrayNilElement,
},
"(error) slice with 1 nil element (and other non-nil elements)": {
input: struct{ NilElement []interface{} }{
[]interface{}{1, nil},
},
wantError: errArrayNilElement,
},
"simple map": {
input: map[string]int{"a": 1, "b": 2},
wantOutput: "a = 1\nb = 2\n",
},
"map with interface{} value type": {
input: map[string]interface{}{"a": 1, "b": "c"},
wantOutput: "a = 1\nb = \"c\"\n",
},
"map with interface{} value type, some of which are structs": {
input: map[string]interface{}{
"a": struct{ Int int }{2},
"b": 1,
},
wantOutput: "b = 1\n\n[a]\n Int = 2\n",
},
"nested map": {
input: map[string]map[string]int{
"a": {"b": 1},
"c": {"d": 2},
},
wantOutput: "[a]\n b = 1\n\n[c]\n d = 2\n",
},
"nested struct": {
input: struct{ Struct struct{ Int int } }{
struct{ Int int }{1},
},
wantOutput: "[Struct]\n Int = 1\n",
},
"nested struct and non-struct field": {
input: struct {
Struct struct{ Int int }
Bool bool
}{struct{ Int int }{1}, true},
wantOutput: "Bool = true\n\n[Struct]\n Int = 1\n",
},
"2 nested structs": {
input: struct{ Struct1, Struct2 struct{ Int int } }{
struct{ Int int }{1}, struct{ Int int }{2},
},
wantOutput: "[Struct1]\n Int = 1\n\n[Struct2]\n Int = 2\n",
},
"deeply nested structs": {
input: struct {
Struct1, Struct2 struct{ Struct3 *struct{ Int int } }
}{
struct{ Struct3 *struct{ Int int } }{&struct{ Int int }{1}},
struct{ Struct3 *struct{ Int int } }{nil},
},
wantOutput: "[Struct1]\n [Struct1.Struct3]\n Int = 1" +
"\n\n[Struct2]\n",
},
"nested struct with nil struct elem": {
input: struct {
Struct struct{ Inner *struct{ Int int } }
}{
struct{ Inner *struct{ Int int } }{nil},
},
wantOutput: "[Struct]\n",
},
"nested struct with no fields": {
input: struct {
Struct struct{ Inner struct{} }
}{
struct{ Inner struct{} }{struct{}{}},
},
wantOutput: "[Struct]\n [Struct.Inner]\n",
},
"struct with tags": {
input: struct {
Struct struct {
Int int `toml:"_int"`
} `toml:"_struct"`
Bool bool `toml:"_bool"`
}{
struct {
Int int `toml:"_int"`
}{1}, true,
},
wantOutput: "_bool = true\n\n[_struct]\n _int = 1\n",
},
"embedded struct": {
input: struct{ Embedded }{Embedded{1}},
wantOutput: "_int = 1\n",
},
"embedded *struct": {
input: struct{ *Embedded }{&Embedded{1}},
wantOutput: "_int = 1\n",
},
"nested embedded struct": {
input: struct {
Struct struct{ Embedded } `toml:"_struct"`
}{struct{ Embedded }{Embedded{1}}},
wantOutput: "[_struct]\n _int = 1\n",
},
"nested embedded *struct": {
input: struct {
Struct struct{ *Embedded } `toml:"_struct"`
}{struct{ *Embedded }{&Embedded{1}}},
wantOutput: "[_struct]\n _int = 1\n",
},
"embedded non-struct": {
input: struct{ NonStruct }{5},
wantOutput: "NonStruct = 5\n",
},
"array of tables": {
input: struct {
Structs []*struct{ Int int } `toml:"struct"`
}{
[]*struct{ Int int }{{1}, {3}},
},
wantOutput: "[[struct]]\n Int = 1\n\n[[struct]]\n Int = 3\n",
},
"array of tables order": {
input: map[string]interface{}{
"map": map[string]interface{}{
"zero": 5,
"arr": []map[string]int{
{
"friend": 5,
},
},
},
},
wantOutput: "[map]\n zero = 5\n\n [[map.arr]]\n friend = 5\n",
},
"(error) top-level slice": {
input: []struct{ Int int }{{1}, {2}, {3}},
wantError: errNoKey,
},
"(error) slice of slice": {
input: struct {
Slices [][]struct{ Int int }
}{
[][]struct{ Int int }{{{1}}, {{2}}, {{3}}},
},
wantError: errArrayNoTable,
},
"(error) map no string key": {
input: map[int]string{1: ""},
wantError: errNonString,
},
"(error) empty key name": {
input: map[string]int{"": 1},
wantError: errAnything,
},
"(error) empty map name": {
input: map[string]interface{}{
"": map[string]int{"v": 1},
},
wantError: errAnything,
},
}
for label, test := range tests {
encodeExpected(t, label, test.input, test.wantOutput, test.wantError)
}
}
func TestEncodeNestedTableArrays(t *testing.T) {
type song struct {
Name string `toml:"name"`
}
type album struct {
Name string `toml:"name"`
Songs []song `toml:"songs"`
}
type springsteen struct {
Albums []album `toml:"albums"`
}
value := springsteen{
[]album{
{"Born to Run",
[]song{{"Jungleland"}, {"Meeting Across the River"}}},
{"Born in the USA",
[]song{{"Glory Days"}, {"Dancing in the Dark"}}},
},
}
expected := `[[albums]]
name = "Born to Run"
[[albums.songs]]
name = "Jungleland"
[[albums.songs]]
name = "Meeting Across the River"
[[albums]]
name = "Born in the USA"
[[albums.songs]]
name = "Glory Days"
[[albums.songs]]
name = "Dancing in the Dark"
`
encodeExpected(t, "nested table arrays", value, expected, nil)
}
func TestEncodeArrayHashWithNormalHashOrder(t *testing.T) {
type Alpha struct {
V int
}
type Beta struct {
V int
}
type Conf struct {
V int
A Alpha
B []Beta
}
val := Conf{
V: 1,
A: Alpha{2},
B: []Beta{{3}},
}
expected := "V = 1\n\n[A]\n V = 2\n\n[[B]]\n V = 3\n"
encodeExpected(t, "array hash with normal hash order", val, expected, nil)
}
func TestEncodeWithOmitEmpty(t *testing.T) {
type simple struct {
Bool bool `toml:"bool,omitempty"`
String string `toml:"string,omitempty"`
Array [0]byte `toml:"array,omitempty"`
Slice []int `toml:"slice,omitempty"`
Map map[string]string `toml:"map,omitempty"`
}
var v simple
encodeExpected(t, "fields with omitempty are omitted when empty", v, "", nil)
v = simple{
Bool: true,
String: " ",
Slice: []int{2, 3, 4},
Map: map[string]string{"foo": "bar"},
}
expected := `bool = true
string = " "
slice = [2, 3, 4]
[map]
foo = "bar"
`
encodeExpected(t, "fields with omitempty are not omitted when non-empty",
v, expected, nil)
}
func TestEncodeWithOmitZero(t *testing.T) {
type simple struct {
Number int `toml:"number,omitzero"`
Real float64 `toml:"real,omitzero"`
Unsigned uint `toml:"unsigned,omitzero"`
}
value := simple{0, 0.0, uint(0)}
expected := ""
encodeExpected(t, "simple with omitzero, all zero", value, expected, nil)
value.Number = 10
value.Real = 20
value.Unsigned = 5
expected = `number = 10
real = 20.0
unsigned = 5
`
encodeExpected(t, "simple with omitzero, non-zero", value, expected, nil)
}
func TestEncodeOmitemptyWithEmptyName(t *testing.T) {
type simple struct {
S []int `toml:",omitempty"`
}
v := simple{[]int{1, 2, 3}}
expected := "S = [1, 2, 3]\n"
encodeExpected(t, "simple with omitempty, no name, non-empty field",
v, expected, nil)
}
func TestEncodeAnonymousStruct(t *testing.T) {
type Inner struct{ N int }
type Outer0 struct{ Inner }
type Outer1 struct {
Inner `toml:"inner"`
}
v0 := Outer0{Inner{3}}
expected := "N = 3\n"
encodeExpected(t, "embedded anonymous untagged struct", v0, expected, nil)
v1 := Outer1{Inner{3}}
expected = "[inner]\n N = 3\n"
encodeExpected(t, "embedded anonymous tagged struct", v1, expected, nil)
}
func TestEncodeAnonymousStructPointerField(t *testing.T) {
type Inner struct{ N int }
type Outer0 struct{ *Inner }
type Outer1 struct {
*Inner `toml:"inner"`
}
v0 := Outer0{}
expected := ""
encodeExpected(t, "nil anonymous untagged struct pointer field", v0, expected, nil)
v0 = Outer0{&Inner{3}}
expected = "N = 3\n"
encodeExpected(t, "non-nil anonymous untagged struct pointer field", v0, expected, nil)
v1 := Outer1{}
expected = ""
encodeExpected(t, "nil anonymous tagged struct pointer field", v1, expected, nil)
v1 = Outer1{&Inner{3}}
expected = "[inner]\n N = 3\n"
encodeExpected(t, "non-nil anonymous tagged struct pointer field", v1, expected, nil)
}
func TestEncodeIgnoredFields(t *testing.T) {
type simple struct {
Number int `toml:"-"`
}
value := simple{}
expected := ""
encodeExpected(t, "ignored field", value, expected, nil)
}
func encodeExpected(
t *testing.T, label string, val interface{}, wantStr string, wantErr error,
) {
var buf bytes.Buffer
enc := NewEncoder(&buf)
err := enc.Encode(val)
if err != wantErr {
if wantErr != nil {
if wantErr == errAnything && err != nil {
return
}
t.Errorf("%s: want Encode error %v, got %v", label, wantErr, err)
} else {
t.Errorf("%s: Encode failed: %s", label, err)
}
}
if err != nil {
return
}
if got := buf.String(); wantStr != got {
t.Errorf("%s: want\n-----\n%q\n-----\nbut got\n-----\n%q\n-----\n",
label, wantStr, got)
}
}
func ExampleEncoder_Encode() {
date, _ := time.Parse(time.RFC822, "14 Mar 10 18:00 UTC")
var config = map[string]interface{}{
"date": date,
"counts": []int{1, 1, 2, 3, 5, 8},
"hash": map[string]string{
"key1": "val1",
"key2": "val2",
},
}
buf := new(bytes.Buffer)
if err := NewEncoder(buf).Encode(config); err != nil {
log.Fatal(err)
}
fmt.Println(buf.String())
// Output:
// counts = [1, 1, 2, 3, 5, 8]
// date = 2010-03-14T18:00:00Z
//
// [hash]
// key1 = "val1"
// key2 = "val2"
}

View File

@@ -1 +0,0 @@
au BufWritePost *.go silent!make tags > /dev/null 2>&1

View File

@@ -1 +0,0 @@
*.exe

View File

@@ -1,22 +0,0 @@
# go-winio
This repository contains utilities for efficiently performing Win32 IO operations in
Go. Currently, this is focused on accessing named pipes and other file handles, and
for using named pipes as a net transport.
This code relies on IO completion ports to avoid blocking IO on system threads, allowing Go
to reuse the thread to schedule another goroutine. This limits support to Windows Vista and
newer operating systems. This is similar to the implementation of network sockets in Go's net
package.
Please see the LICENSE file for licensing information.
This project has adopted the [Microsoft Open Source Code of
Conduct](https://opensource.microsoft.com/codeofconduct/). For more information
see the [Code of Conduct
FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact
[opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional
questions or comments.
Thanks to natefinch for the inspiration for this library. See https://github.com/natefinch/npipe
for another named pipe implementation.

View File

@@ -1,344 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package tar implements access to tar archives.
// It aims to cover most of the variations, including those produced
// by GNU and BSD tars.
//
// References:
// http://www.freebsd.org/cgi/man.cgi?query=tar&sektion=5
// http://www.gnu.org/software/tar/manual/html_node/Standard.html
// http://pubs.opengroup.org/onlinepubs/9699919799/utilities/pax.html
package tar
import (
"bytes"
"errors"
"fmt"
"os"
"path"
"time"
)
const (
blockSize = 512
// Types
TypeReg = '0' // regular file
TypeRegA = '\x00' // regular file
TypeLink = '1' // hard link
TypeSymlink = '2' // symbolic link
TypeChar = '3' // character device node
TypeBlock = '4' // block device node
TypeDir = '5' // directory
TypeFifo = '6' // fifo node
TypeCont = '7' // reserved
TypeXHeader = 'x' // extended header
TypeXGlobalHeader = 'g' // global extended header
TypeGNULongName = 'L' // Next file has a long name
TypeGNULongLink = 'K' // Next file symlinks to a file w/ a long name
TypeGNUSparse = 'S' // sparse file
)
// A Header represents a single header in a tar archive.
// Some fields may not be populated.
type Header struct {
Name string // name of header file entry
Mode int64 // permission and mode bits
Uid int // user id of owner
Gid int // group id of owner
Size int64 // length in bytes
ModTime time.Time // modified time
Typeflag byte // type of header entry
Linkname string // target name of link
Uname string // user name of owner
Gname string // group name of owner
Devmajor int64 // major number of character or block device
Devminor int64 // minor number of character or block device
AccessTime time.Time // access time
ChangeTime time.Time // status change time
CreationTime time.Time // creation time
Xattrs map[string]string
Winheaders map[string]string
}
// File name constants from the tar spec.
const (
fileNameSize = 100 // Maximum number of bytes in a standard tar name.
fileNamePrefixSize = 155 // Maximum number of ustar extension bytes.
)
// FileInfo returns an os.FileInfo for the Header.
func (h *Header) FileInfo() os.FileInfo {
return headerFileInfo{h}
}
// headerFileInfo implements os.FileInfo.
type headerFileInfo struct {
h *Header
}
func (fi headerFileInfo) Size() int64 { return fi.h.Size }
func (fi headerFileInfo) IsDir() bool { return fi.Mode().IsDir() }
func (fi headerFileInfo) ModTime() time.Time { return fi.h.ModTime }
func (fi headerFileInfo) Sys() interface{} { return fi.h }
// Name returns the base name of the file.
func (fi headerFileInfo) Name() string {
if fi.IsDir() {
return path.Base(path.Clean(fi.h.Name))
}
return path.Base(fi.h.Name)
}
// Mode returns the permission and mode bits for the headerFileInfo.
func (fi headerFileInfo) Mode() (mode os.FileMode) {
// Set file permission bits.
mode = os.FileMode(fi.h.Mode).Perm()
// Set setuid, setgid and sticky bits.
if fi.h.Mode&c_ISUID != 0 {
// setuid
mode |= os.ModeSetuid
}
if fi.h.Mode&c_ISGID != 0 {
// setgid
mode |= os.ModeSetgid
}
if fi.h.Mode&c_ISVTX != 0 {
// sticky
mode |= os.ModeSticky
}
// Set file mode bits.
// clear perm, setuid, setgid and sticky bits.
m := os.FileMode(fi.h.Mode) &^ 07777
if m == c_ISDIR {
// directory
mode |= os.ModeDir
}
if m == c_ISFIFO {
// named pipe (FIFO)
mode |= os.ModeNamedPipe
}
if m == c_ISLNK {
// symbolic link
mode |= os.ModeSymlink
}
if m == c_ISBLK {
// device file
mode |= os.ModeDevice
}
if m == c_ISCHR {
// Unix character device
mode |= os.ModeDevice
mode |= os.ModeCharDevice
}
if m == c_ISSOCK {
// Unix domain socket
mode |= os.ModeSocket
}
switch fi.h.Typeflag {
case TypeSymlink:
// symbolic link
mode |= os.ModeSymlink
case TypeChar:
// character device node
mode |= os.ModeDevice
mode |= os.ModeCharDevice
case TypeBlock:
// block device node
mode |= os.ModeDevice
case TypeDir:
// directory
mode |= os.ModeDir
case TypeFifo:
// fifo node
mode |= os.ModeNamedPipe
}
return mode
}
// sysStat, if non-nil, populates h from system-dependent fields of fi.
var sysStat func(fi os.FileInfo, h *Header) error
// Mode constants from the tar spec.
const (
c_ISUID = 04000 // Set uid
c_ISGID = 02000 // Set gid
c_ISVTX = 01000 // Save text (sticky bit)
c_ISDIR = 040000 // Directory
c_ISFIFO = 010000 // FIFO
c_ISREG = 0100000 // Regular file
c_ISLNK = 0120000 // Symbolic link
c_ISBLK = 060000 // Block special file
c_ISCHR = 020000 // Character special file
c_ISSOCK = 0140000 // Socket
)
// Keywords for the PAX Extended Header
const (
paxAtime = "atime"
paxCharset = "charset"
paxComment = "comment"
paxCtime = "ctime" // please note that ctime is not a valid pax header.
paxCreationTime = "LIBARCHIVE.creationtime"
paxGid = "gid"
paxGname = "gname"
paxLinkpath = "linkpath"
paxMtime = "mtime"
paxPath = "path"
paxSize = "size"
paxUid = "uid"
paxUname = "uname"
paxXattr = "SCHILY.xattr."
paxWindows = "MSWINDOWS."
paxNone = ""
)
// FileInfoHeader creates a partially-populated Header from fi.
// If fi describes a symlink, FileInfoHeader records link as the link target.
// If fi describes a directory, a slash is appended to the name.
// Because os.FileInfo's Name method returns only the base name of
// the file it describes, it may be necessary to modify the Name field
// of the returned header to provide the full path name of the file.
func FileInfoHeader(fi os.FileInfo, link string) (*Header, error) {
if fi == nil {
return nil, errors.New("tar: FileInfo is nil")
}
fm := fi.Mode()
h := &Header{
Name: fi.Name(),
ModTime: fi.ModTime(),
Mode: int64(fm.Perm()), // or'd with c_IS* constants later
}
switch {
case fm.IsRegular():
h.Mode |= c_ISREG
h.Typeflag = TypeReg
h.Size = fi.Size()
case fi.IsDir():
h.Typeflag = TypeDir
h.Mode |= c_ISDIR
h.Name += "/"
case fm&os.ModeSymlink != 0:
h.Typeflag = TypeSymlink
h.Mode |= c_ISLNK
h.Linkname = link
case fm&os.ModeDevice != 0:
if fm&os.ModeCharDevice != 0 {
h.Mode |= c_ISCHR
h.Typeflag = TypeChar
} else {
h.Mode |= c_ISBLK
h.Typeflag = TypeBlock
}
case fm&os.ModeNamedPipe != 0:
h.Typeflag = TypeFifo
h.Mode |= c_ISFIFO
case fm&os.ModeSocket != 0:
h.Mode |= c_ISSOCK
default:
return nil, fmt.Errorf("archive/tar: unknown file mode %v", fm)
}
if fm&os.ModeSetuid != 0 {
h.Mode |= c_ISUID
}
if fm&os.ModeSetgid != 0 {
h.Mode |= c_ISGID
}
if fm&os.ModeSticky != 0 {
h.Mode |= c_ISVTX
}
// If possible, populate additional fields from OS-specific
// FileInfo fields.
if sys, ok := fi.Sys().(*Header); ok {
// This FileInfo came from a Header (not the OS). Use the
// original Header to populate all remaining fields.
h.Uid = sys.Uid
h.Gid = sys.Gid
h.Uname = sys.Uname
h.Gname = sys.Gname
h.AccessTime = sys.AccessTime
h.ChangeTime = sys.ChangeTime
if sys.Xattrs != nil {
h.Xattrs = make(map[string]string)
for k, v := range sys.Xattrs {
h.Xattrs[k] = v
}
}
if sys.Typeflag == TypeLink {
// hard link
h.Typeflag = TypeLink
h.Size = 0
h.Linkname = sys.Linkname
}
}
if sysStat != nil {
return h, sysStat(fi, h)
}
return h, nil
}
var zeroBlock = make([]byte, blockSize)
// POSIX specifies a sum of the unsigned byte values, but the Sun tar uses signed byte values.
// We compute and return both.
func checksum(header []byte) (unsigned int64, signed int64) {
for i := 0; i < len(header); i++ {
if i == 148 {
// The chksum field (header[148:156]) is special: it should be treated as space bytes.
unsigned += ' ' * 8
signed += ' ' * 8
i += 7
continue
}
unsigned += int64(header[i])
signed += int64(int8(header[i]))
}
return
}
type slicer []byte
func (sp *slicer) next(n int) (b []byte) {
s := *sp
b, *sp = s[0:n], s[n:]
return
}
func isASCII(s string) bool {
for _, c := range s {
if c >= 0x80 {
return false
}
}
return true
}
func toASCII(s string) string {
if isASCII(s) {
return s
}
var buf bytes.Buffer
for _, c := range s {
if c < 0x80 {
buf.WriteByte(byte(c))
}
}
return buf.String()
}
// isHeaderOnlyType checks if the given type flag is of the type that has no
// data section even if a size is specified.
func isHeaderOnlyType(flag byte) bool {
switch flag {
case TypeLink, TypeSymlink, TypeChar, TypeBlock, TypeDir, TypeFifo:
return true
default:
return false
}
}

View File

@@ -1,80 +0,0 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tar_test
import (
"archive/tar"
"bytes"
"fmt"
"io"
"log"
"os"
)
func Example() {
// Create a buffer to write our archive to.
buf := new(bytes.Buffer)
// Create a new tar archive.
tw := tar.NewWriter(buf)
// Add some files to the archive.
var files = []struct {
Name, Body string
}{
{"readme.txt", "This archive contains some text files."},
{"gopher.txt", "Gopher names:\nGeorge\nGeoffrey\nGonzo"},
{"todo.txt", "Get animal handling license."},
}
for _, file := range files {
hdr := &tar.Header{
Name: file.Name,
Mode: 0600,
Size: int64(len(file.Body)),
}
if err := tw.WriteHeader(hdr); err != nil {
log.Fatalln(err)
}
if _, err := tw.Write([]byte(file.Body)); err != nil {
log.Fatalln(err)
}
}
// Make sure to check the error on Close.
if err := tw.Close(); err != nil {
log.Fatalln(err)
}
// Open the tar archive for reading.
r := bytes.NewReader(buf.Bytes())
tr := tar.NewReader(r)
// Iterate through the files in the archive.
for {
hdr, err := tr.Next()
if err == io.EOF {
// end of tar archive
break
}
if err != nil {
log.Fatalln(err)
}
fmt.Printf("Contents of %s:\n", hdr.Name)
if _, err := io.Copy(os.Stdout, tr); err != nil {
log.Fatalln(err)
}
fmt.Println()
}
// Output:
// Contents of readme.txt:
// This archive contains some text files.
// Contents of gopher.txt:
// Gopher names:
// George
// Geoffrey
// Gonzo
// Contents of todo.txt:
// Get animal handling license.
}

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More