diff --git a/cmd/virtual-kubelet/internal/provider/mock/mock.go b/cmd/virtual-kubelet/internal/provider/mock/mock.go index de4757117..eeb130e42 100644 --- a/cmd/virtual-kubelet/internal/provider/mock/mock.go +++ b/cmd/virtual-kubelet/internal/provider/mock/mock.go @@ -308,6 +308,12 @@ func (p *MockProvider) AttachToContainer(ctx context.Context, namespace, name, c return nil } +// PortForward forwards a local port to a port on the pod +func (p *MockProvider) PortForward(ctx context.Context, namespace, pod string, port int32, stream io.ReadWriteCloser) error { + log.G(ctx).Infof("receive PortForward %q", pod) + return nil +} + // GetPodStatus returns the status of a pod by name that is "running". // returns nil if a pod by that name is not found. func (p *MockProvider) GetPodStatus(ctx context.Context, namespace, name string) (*v1.PodStatus, error) { diff --git a/internal/kubernetes/portforward/constants.go b/internal/kubernetes/portforward/constants.go new file mode 100644 index 000000000..62b14f205 --- /dev/null +++ b/internal/kubernetes/portforward/constants.go @@ -0,0 +1,24 @@ +/* +Copyright 2015 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package portforward contains server-side logic for handling port forwarding requests. +package portforward + +// ProtocolV1Name is the name of the subprotocol used for port forwarding. +const ProtocolV1Name = "portforward.k8s.io" + +// SupportedProtocols are the supported port forwarding protocols. +var SupportedProtocols = []string{ProtocolV1Name} diff --git a/internal/kubernetes/portforward/httpstream.go b/internal/kubernetes/portforward/httpstream.go new file mode 100644 index 000000000..ea5ce9880 --- /dev/null +++ b/internal/kubernetes/portforward/httpstream.go @@ -0,0 +1,317 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package portforward + +import ( + "context" + "errors" + "fmt" + "net/http" + "strconv" + "sync" + "time" + + api "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/httpstream" + "k8s.io/apimachinery/pkg/util/httpstream/spdy" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" + + "k8s.io/klog/v2" +) + +func handleHTTPStreams(req *http.Request, w http.ResponseWriter, portForwarder PortForwarder, podName string, uid types.UID, supportedPortForwardProtocols []string, idleTimeout, streamCreationTimeout time.Duration) error { + _, err := httpstream.Handshake(req, w, supportedPortForwardProtocols) + // negotiated protocol isn't currently used server side, but could be in the future + if err != nil { + // Handshake writes the error to the client + return err + } + streamChan := make(chan httpstream.Stream, 1) + + klog.V(5).InfoS("Upgrading port forward response") + + // TODO aka-somix: SPDY is deprecated and it should be replaced in order to support HTTP/2 + upgrader := spdy.NewResponseUpgrader() + conn := upgrader.UpgradeResponse(w, req, httpStreamReceived(streamChan)) + if conn == nil { + return errors.New("unable to upgrade httpstream connection") + } + defer conn.Close() + + klog.V(5).InfoS("Connection setting port forwarding streaming connection idle timeout", "connection", conn, "idleTimeout", idleTimeout) + conn.SetIdleTimeout(idleTimeout) + + h := &httpStreamHandler{ + conn: conn, + streamChan: streamChan, + streamPairs: make(map[string]*httpStreamPair), + streamCreationTimeout: streamCreationTimeout, + pod: podName, + uid: uid, + forwarder: portForwarder, + } + h.run() + + return nil +} + +// httpStreamReceived is the httpstream.NewStreamHandler for port +// forward streams. It checks each stream's port and stream type headers, +// rejecting any streams that with missing or invalid values. Each valid +// stream is sent to the streams channel. +func httpStreamReceived(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error { + return func(stream httpstream.Stream, replySent <-chan struct{}) error { + // make sure it has a valid port header + portString := stream.Headers().Get(api.PortHeader) + if len(portString) == 0 { + return fmt.Errorf("%q header is required", api.PortHeader) + } + port, err := strconv.ParseUint(portString, 10, 16) + if err != nil { + return fmt.Errorf("unable to parse %q as a port: %v", portString, err) + } + if port < 1 { + return fmt.Errorf("port %q must be > 0", portString) + } + + // make sure it has a valid stream type header + streamType := stream.Headers().Get(api.StreamType) + if len(streamType) == 0 { + return fmt.Errorf("%q header is required", api.StreamType) + } + if streamType != api.StreamTypeError && streamType != api.StreamTypeData { + return fmt.Errorf("invalid stream type %q", streamType) + } + + streams <- stream + return nil + } +} + +// httpStreamHandler is capable of processing multiple port forward +// requests over a single httpstream.Connection. +type httpStreamHandler struct { + conn httpstream.Connection + streamChan chan httpstream.Stream + streamPairsLock sync.RWMutex + streamPairs map[string]*httpStreamPair + streamCreationTimeout time.Duration + pod string + uid types.UID + forwarder PortForwarder +} + +// getStreamPair returns a httpStreamPair for requestID. This creates a +// new pair if one does not yet exist for the requestID. The returned bool is +// true if the pair was created. +func (h *httpStreamHandler) getStreamPair(requestID string) (*httpStreamPair, bool) { + h.streamPairsLock.Lock() + defer h.streamPairsLock.Unlock() + + if p, ok := h.streamPairs[requestID]; ok { + klog.V(5).InfoS("Connection request found existing stream pair", "connection", h.conn, "request", requestID) + return p, false + } + + klog.V(5).InfoS("Connection request creating new stream pair", "connection", h.conn, "request", requestID) + + p := newPortForwardPair(requestID) + h.streamPairs[requestID] = p + + return p, true +} + +// monitorStreamPair waits for the pair to receive both its error and data +// streams, or for the timeout to expire (whichever happens first), and then +// removes the pair. +func (h *httpStreamHandler) monitorStreamPair(p *httpStreamPair, timeout <-chan time.Time) { + select { + case <-timeout: + err := fmt.Errorf("(conn=%v, request=%s) timed out waiting for streams", h.conn, p.requestID) + utilruntime.HandleError(err) + p.printError(err.Error()) + case <-p.complete: + klog.V(5).InfoS("Connection request successfully received error and data streams", "connection", h.conn, "request", p.requestID) + } + h.removeStreamPair(p.requestID) +} + +// hasStreamPair returns a bool indicating if a stream pair for requestID +// exists. +func (h *httpStreamHandler) hasStreamPair(requestID string) bool { + h.streamPairsLock.RLock() + defer h.streamPairsLock.RUnlock() + + _, ok := h.streamPairs[requestID] + return ok +} + +// removeStreamPair removes the stream pair identified by requestID from streamPairs. +func (h *httpStreamHandler) removeStreamPair(requestID string) { + h.streamPairsLock.Lock() + defer h.streamPairsLock.Unlock() + + if h.conn != nil { + pair := h.streamPairs[requestID] + h.conn.RemoveStreams(pair.dataStream, pair.errorStream) + } + delete(h.streamPairs, requestID) +} + +// requestID returns the request id for stream. +func (h *httpStreamHandler) requestID(stream httpstream.Stream) string { + requestID := stream.Headers().Get(api.PortForwardRequestIDHeader) + if len(requestID) == 0 { + klog.V(5).InfoS("Connection stream received without requestID header", "connection", h.conn) + // If we get here, it's because the connection came from an older client + // that isn't generating the request id header + // (https://github.com/kubernetes/kubernetes/blob/843134885e7e0b360eb5441e85b1410a8b1a7a0c/pkg/client/unversioned/portforward/portforward.go#L258-L287) + // + // This is a best-effort attempt at supporting older clients. + // + // When there aren't concurrent new forwarded connections, each connection + // will have a pair of streams (data, error), and the stream IDs will be + // consecutive odd numbers, e.g. 1 and 3 for the first connection. Convert + // the stream ID into a pseudo-request id by taking the stream type and + // using id = stream.Identifier() when the stream type is error, + // and id = stream.Identifier() - 2 when it's data. + // + // NOTE: this only works when there are not concurrent new streams from + // multiple forwarded connections; it's a best-effort attempt at supporting + // old clients that don't generate request ids. If there are concurrent + // new connections, it's possible that 1 connection gets streams whose IDs + // are not consecutive (e.g. 5 and 9 instead of 5 and 7). + streamType := stream.Headers().Get(api.StreamType) + switch streamType { + case api.StreamTypeError: + requestID = strconv.Itoa(int(stream.Identifier())) + case api.StreamTypeData: + requestID = strconv.Itoa(int(stream.Identifier()) - 2) + } + + klog.V(5).InfoS("Connection automatically assigning request ID from stream type and stream ID", "connection", h.conn, "request", requestID, "streamType", streamType, "stream", stream.Identifier()) + } + return requestID +} + +// run is the main loop for the httpStreamHandler. It processes new +// streams, invoking portForward for each complete stream pair. The loop exits +// when the httpstream.Connection is closed. +func (h *httpStreamHandler) run() { + klog.V(5).InfoS("Connection waiting for port forward streams", "connection", h.conn) +Loop: + for { + select { + case <-h.conn.CloseChan(): + klog.V(5).InfoS("Connection upgraded connection closed", "connection", h.conn) + break Loop + case stream := <-h.streamChan: + requestID := h.requestID(stream) + streamType := stream.Headers().Get(api.StreamType) + klog.V(5).InfoS("Connection request received new type of stream", "connection", h.conn, "request", requestID, "streamType", streamType) + + p, created := h.getStreamPair(requestID) + if created { + go h.monitorStreamPair(p, time.After(h.streamCreationTimeout)) + } + if complete, err := p.add(stream); err != nil { + msg := fmt.Sprintf("error processing stream for request %s: %v", requestID, err) + utilruntime.HandleError(errors.New(msg)) + p.printError(msg) + } else if complete { + go h.portForward(p) + } + } + } +} + +// portForward invokes the httpStreamHandler's forwarder.PortForward +// function for the given stream pair. +func (h *httpStreamHandler) portForward(p *httpStreamPair) { + ctx := context.Background() + defer p.dataStream.Close() + defer p.errorStream.Close() + + portString := p.dataStream.Headers().Get(api.PortHeader) + port, _ := strconv.ParseInt(portString, 10, 32) + + klog.V(5).InfoS("Connection request invoking forwarder.PortForward for port", "connection", h.conn, "request", p.requestID, "port", portString) + err := h.forwarder.PortForward(ctx, h.pod, h.uid, int32(port), p.dataStream) + klog.V(5).InfoS("Connection request done invoking forwarder.PortForward for port", "connection", h.conn, "request", p.requestID, "port", portString) + + if err != nil { + msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", port, h.pod, h.uid, err) + utilruntime.HandleError(msg) + fmt.Fprint(p.errorStream, msg.Error()) + } +} + +// httpStreamPair represents the error and data streams for a port +// forwarding request. +type httpStreamPair struct { + lock sync.RWMutex + requestID string + dataStream httpstream.Stream + errorStream httpstream.Stream + complete chan struct{} +} + +// newPortForwardPair creates a new httpStreamPair. +func newPortForwardPair(requestID string) *httpStreamPair { + return &httpStreamPair{ + requestID: requestID, + complete: make(chan struct{}), + } +} + +// add adds the stream to the httpStreamPair. If the pair already +// contains a stream for the new stream's type, an error is returned. add +// returns true if both the data and error streams for this pair have been +// received. +func (p *httpStreamPair) add(stream httpstream.Stream) (bool, error) { + p.lock.Lock() + defer p.lock.Unlock() + + switch stream.Headers().Get(api.StreamType) { + case api.StreamTypeError: + if p.errorStream != nil { + return false, errors.New("error stream already assigned") + } + p.errorStream = stream + case api.StreamTypeData: + if p.dataStream != nil { + return false, errors.New("data stream already assigned") + } + p.dataStream = stream + } + + complete := p.errorStream != nil && p.dataStream != nil + if complete { + close(p.complete) + } + return complete, nil +} + +// printError writes s to p.errorStream if p.errorStream has been set. +func (p *httpStreamPair) printError(s string) { + p.lock.RLock() + defer p.lock.RUnlock() + if p.errorStream != nil { + fmt.Fprint(p.errorStream, s) + } +} diff --git a/internal/kubernetes/portforward/httpstream_test.go b/internal/kubernetes/portforward/httpstream_test.go new file mode 100644 index 000000000..f594756dc --- /dev/null +++ b/internal/kubernetes/portforward/httpstream_test.go @@ -0,0 +1,267 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package portforward + +import ( + "net/http" + "testing" + "time" + + api "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/util/httpstream" +) + +func TestHTTPStreamReceived(t *testing.T) { + tests := map[string]struct { + port string + streamType string + expectedError string + }{ + "missing port": { + expectedError: `"port" header is required`, + }, + "unable to parse port": { + port: "abc", + expectedError: `unable to parse "abc" as a port: strconv.ParseUint: parsing "abc": invalid syntax`, + }, + "negative port": { + port: "-1", + expectedError: `unable to parse "-1" as a port: strconv.ParseUint: parsing "-1": invalid syntax`, + }, + "missing stream type": { + port: "80", + expectedError: `"streamType" header is required`, + }, + "valid port with error stream": { + port: "80", + streamType: "error", + }, + "valid port with data stream": { + port: "80", + streamType: "data", + }, + "invalid stream type": { + port: "80", + streamType: "foo", + expectedError: `invalid stream type "foo"`, + }, + } + for name, test := range tests { + streams := make(chan httpstream.Stream, 1) + f := httpStreamReceived(streams) + stream := newFakeHTTPStream() + if len(test.port) > 0 { + stream.headers.Set("port", test.port) + } + if len(test.streamType) > 0 { + stream.headers.Set("streamType", test.streamType) + } + replySent := make(chan struct{}) + err := f(stream, replySent) + close(replySent) + if len(test.expectedError) > 0 { + if err == nil { + t.Errorf("%s: expected err=%q, but it was nil", name, test.expectedError) + } + if e, a := test.expectedError, err.Error(); e != a { + t.Errorf("%s: expected err=%q, got %q", name, e, a) + } + continue + } + if err != nil { + t.Errorf("%s: unexpected error %v", name, err) + continue + } + if s := <-streams; s != stream { + t.Errorf("%s: expected stream %#v, got %#v", name, stream, s) + } + } +} + +type fakeConn struct { + removeStreamsCalled bool +} + +func (*fakeConn) CreateStream(headers http.Header) (httpstream.Stream, error) { return nil, nil } +func (*fakeConn) Close() error { return nil } +func (*fakeConn) CloseChan() <-chan bool { return nil } +func (*fakeConn) SetIdleTimeout(timeout time.Duration) {} +func (f *fakeConn) RemoveStreams(streams ...httpstream.Stream) { f.removeStreamsCalled = true } + +func TestGetStreamPair(t *testing.T) { + timeout := make(chan time.Time) + + conn := &fakeConn{} + h := &httpStreamHandler{ + streamPairs: make(map[string]*httpStreamPair), + conn: conn, + } + + // test adding a new entry + p, created := h.getStreamPair("1") + if p == nil { + t.Fatalf("unexpected nil pair") + } + if !created { + t.Fatal("expected created=true") + } + if p.dataStream != nil { + t.Errorf("unexpected non-nil data stream") + } + if p.errorStream != nil { + t.Errorf("unexpected non-nil error stream") + } + + // start the monitor for this pair + monitorDone := make(chan struct{}) + go func() { + h.monitorStreamPair(p, timeout) + close(monitorDone) + }() + + if !h.hasStreamPair("1") { + t.Fatal("This should still be true") + } + + // make sure we can retrieve an existing entry + p2, created := h.getStreamPair("1") + if created { + t.Fatal("expected created=false") + } + if p != p2 { + t.Fatalf("retrieving an existing pair: expected %#v, got %#v", p, p2) + } + + // removed via complete + dataStream := newFakeHTTPStream() + dataStream.headers.Set(api.StreamType, api.StreamTypeData) + complete, err := p.add(dataStream) + if err != nil { + t.Fatalf("unexpected error adding data stream to pair: %v", err) + } + if complete { + t.Fatalf("unexpected complete") + } + + errorStream := newFakeHTTPStream() + errorStream.headers.Set(api.StreamType, api.StreamTypeError) + complete, err = p.add(errorStream) + if err != nil { + t.Fatalf("unexpected error adding error stream to pair: %v", err) + } + if !complete { + t.Fatal("unexpected incomplete") + } + + // make sure monitorStreamPair completed + <-monitorDone + + if !conn.removeStreamsCalled { + t.Fatalf("connection remove stream not called") + } + conn.removeStreamsCalled = false + + // make sure the pair was removed + if h.hasStreamPair("1") { + t.Fatal("expected removal of pair after both data and error streams received") + } + + // removed via timeout + p, created = h.getStreamPair("2") + if !created { + t.Fatal("expected created=true") + } + if p == nil { + t.Fatal("expected p not to be nil") + } + + monitorDone = make(chan struct{}) + go func() { + h.monitorStreamPair(p, timeout) + close(monitorDone) + }() + // cause the timeout + close(timeout) + // make sure monitorStreamPair completed + <-monitorDone + if h.hasStreamPair("2") { + t.Fatal("expected stream pair to be removed") + } + if !conn.removeStreamsCalled { + t.Fatalf("connection remove stream not called") + } +} + +func TestRequestID(t *testing.T) { + h := &httpStreamHandler{} + + s := newFakeHTTPStream() + s.headers.Set(api.StreamType, api.StreamTypeError) + s.id = 1 + if e, a := "1", h.requestID(s); e != a { + t.Errorf("expected %q, got %q", e, a) + } + + s.headers.Set(api.StreamType, api.StreamTypeData) + s.id = 3 + if e, a := "1", h.requestID(s); e != a { + t.Errorf("expected %q, got %q", e, a) + } + + s.id = 7 + s.headers.Set(api.PortForwardRequestIDHeader, "2") + if e, a := "2", h.requestID(s); e != a { + t.Errorf("expected %q, got %q", e, a) + } +} + +type fakeHTTPStream struct { + headers http.Header + id uint32 +} + +func newFakeHTTPStream() *fakeHTTPStream { + return &fakeHTTPStream{ + headers: make(http.Header), + } +} + +var _ httpstream.Stream = &fakeHTTPStream{} + +func (s *fakeHTTPStream) Read(data []byte) (int, error) { + return 0, nil +} + +func (s *fakeHTTPStream) Write(data []byte) (int, error) { + return 0, nil +} + +func (s *fakeHTTPStream) Close() error { + return nil +} + +func (s *fakeHTTPStream) Reset() error { + return nil +} + +func (s *fakeHTTPStream) Headers() http.Header { + return s.headers +} + +func (s *fakeHTTPStream) Identifier() uint32 { + return s.id +} diff --git a/internal/kubernetes/portforward/portforward.go b/internal/kubernetes/portforward/portforward.go new file mode 100644 index 000000000..df0fe5a8e --- /dev/null +++ b/internal/kubernetes/portforward/portforward.go @@ -0,0 +1,54 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package portforward + +import ( + "context" + "io" + "net/http" + "time" + + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/runtime" + "k8s.io/apiserver/pkg/util/wsstream" +) + +// PortForwarder knows how to forward content from a data stream to/from a port +// in a pod. +type PortForwarder interface { + // PortForwarder copies data between a data stream and a port in a pod. + PortForward(ctx context.Context, name string, uid types.UID, port int32, stream io.ReadWriteCloser) error +} + +// ServePortForward handles a port forwarding request. A single request is +// kept alive as long as the client is still alive and the connection has not +// been timed out due to idleness. This function handles multiple forwarded +// connections; i.e., multiple `curl http://localhost:8888/` requests will be +// handled by a single invocation of ServePortForward. +func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, portForwardOptions *V4Options, idleTimeout time.Duration, streamCreationTimeout time.Duration, supportedProtocols []string) { + var err error + if wsstream.IsWebSocketRequest(req) { + err = handleWebSocketStreams(req, w, portForwarder, podName, uid, portForwardOptions, supportedProtocols, idleTimeout, streamCreationTimeout) + } else { + err = handleHTTPStreams(req, w, portForwarder, podName, uid, supportedProtocols, idleTimeout, streamCreationTimeout) + } + + if err != nil { + runtime.HandleError(err) + return + } +} diff --git a/internal/kubernetes/portforward/websocket.go b/internal/kubernetes/portforward/websocket.go new file mode 100644 index 000000000..cbedb5b6c --- /dev/null +++ b/internal/kubernetes/portforward/websocket.go @@ -0,0 +1,199 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package portforward + +import ( + "context" + "encoding/binary" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "sync" + "time" + + "k8s.io/klog/v2" + + api "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/runtime" + "k8s.io/apiserver/pkg/endpoints/responsewriter" + "k8s.io/apiserver/pkg/util/wsstream" +) + +const ( + dataChannel = iota + errorChannel + + v4BinaryWebsocketProtocol = "v4." + wsstream.ChannelWebSocketProtocol + v4Base64WebsocketProtocol = "v4." + wsstream.Base64ChannelWebSocketProtocol +) + +// V4Options contains details about which streams are required for port +// forwarding. +// All fields included in V4Options need to be expressed explicitly in the +// CRI (k8s.io/cri-api/pkg/apis/{version}/api.proto) PortForwardRequest. +type V4Options struct { + Ports []int32 +} + +// NewV4Options creates a new options from the Request. +func NewV4Options(req *http.Request) (*V4Options, error) { + if !wsstream.IsWebSocketRequest(req) { + return &V4Options{}, nil + } + + portStrings := req.URL.Query()[api.PortHeader] + if len(portStrings) == 0 { + return nil, fmt.Errorf("query parameter %q is required", api.PortHeader) + } + + ports := make([]int32, 0, len(portStrings)) + for _, portString := range portStrings { + if len(portString) == 0 { + return nil, fmt.Errorf("query parameter %q cannot be empty", api.PortHeader) + } + for _, p := range strings.Split(portString, ",") { + port, err := strconv.ParseUint(p, 10, 16) + if err != nil { + return nil, fmt.Errorf("unable to parse %q as a port: %v", portString, err) + } + if port < 1 { + return nil, fmt.Errorf("port %q must be > 0", portString) + } + ports = append(ports, int32(port)) + } + } + + return &V4Options{ + Ports: ports, + }, nil +} + +// BuildV4Options returns a V4Options based on the given information. +func BuildV4Options(ports []int32) (*V4Options, error) { + return &V4Options{Ports: ports}, nil +} + +// handleWebSocketStreams handles requests to forward ports to a pod via +// a PortForwarder. A pair of streams are created per port (DATA n, +// ERROR n+1). The associated port is written to each stream as a unsigned 16 +// bit integer in little endian format. +func handleWebSocketStreams(req *http.Request, w http.ResponseWriter, portForwarder PortForwarder, podName string, uid types.UID, opts *V4Options, supportedPortForwardProtocols []string, idleTimeout, streamCreationTimeout time.Duration) error { + channels := make([]wsstream.ChannelType, 0, len(opts.Ports)*2) + for i := 0; i < len(opts.Ports); i++ { + channels = append(channels, wsstream.ReadWriteChannel, wsstream.WriteChannel) + } + conn := wsstream.NewConn(map[string]wsstream.ChannelProtocolConfig{ + "": { + Binary: true, + Channels: channels, + }, + v4BinaryWebsocketProtocol: { + Binary: true, + Channels: channels, + }, + v4Base64WebsocketProtocol: { + Binary: false, + Channels: channels, + }, + }) + conn.SetIdleTimeout(idleTimeout) + _, streams, err := conn.Open(responsewriter.GetOriginal(w), req) + if err != nil { + err = fmt.Errorf("unable to upgrade websocket connection: %v", err) + return err + } + defer conn.Close() + streamPairs := make([]*websocketStreamPair, len(opts.Ports)) + for i := range streamPairs { + streamPair := websocketStreamPair{ + port: opts.Ports[i], + dataStream: streams[i*2+dataChannel], + errorStream: streams[i*2+errorChannel], + } + streamPairs[i] = &streamPair + + portBytes := make([]byte, 2) + // port is always positive so conversion is allowable + binary.LittleEndian.PutUint16(portBytes, uint16(streamPair.port)) + streamPair.dataStream.Write(portBytes) + streamPair.errorStream.Write(portBytes) + } + h := &websocketStreamHandler{ + conn: conn, + streamPairs: streamPairs, + pod: podName, + uid: uid, + forwarder: portForwarder, + } + h.run() + + return nil +} + +// websocketStreamPair represents the error and data streams for a port +// forwarding request. +type websocketStreamPair struct { + port int32 + dataStream io.ReadWriteCloser + errorStream io.WriteCloser +} + +// websocketStreamHandler is capable of processing a single port forward +// request over a websocket connection +type websocketStreamHandler struct { + conn *wsstream.Conn + streamPairs []*websocketStreamPair + pod string + uid types.UID + forwarder PortForwarder +} + +// run invokes the websocketStreamHandler's forwarder.PortForward +// function for the given stream pair. +func (h *websocketStreamHandler) run() { + wg := sync.WaitGroup{} + wg.Add(len(h.streamPairs)) + + for _, pair := range h.streamPairs { + p := pair + go func() { + defer wg.Done() + h.portForward(p) + }() + } + + wg.Wait() +} + +func (h *websocketStreamHandler) portForward(p *websocketStreamPair) { + ctx := context.Background() + defer p.dataStream.Close() + defer p.errorStream.Close() + + klog.V(5).InfoS("Connection invoking forwarder.PortForward for port", "connection", h.conn, "port", p.port) + err := h.forwarder.PortForward(ctx, h.pod, h.uid, p.port, p.dataStream) + klog.V(5).InfoS("Connection done invoking forwarder.PortForward for port", "connection", h.conn, "port", p.port) + + if err != nil { + msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", p.port, h.pod, h.uid, err) + runtime.HandleError(msg) + fmt.Fprint(p.errorStream, msg.Error()) + } +} diff --git a/internal/kubernetes/portforward/websocket_test.go b/internal/kubernetes/portforward/websocket_test.go new file mode 100644 index 000000000..8f508d9bd --- /dev/null +++ b/internal/kubernetes/portforward/websocket_test.go @@ -0,0 +1,101 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package portforward + +import ( + "net/http" + "reflect" + "testing" +) + +func TestV4Options(t *testing.T) { + tests := map[string]struct { + url string + websocket bool + expectedOpts *V4Options + expectedError string + }{ + "non-ws request": { + url: "http://example.com", + expectedOpts: &V4Options{}, + }, + "missing port": { + url: "http://example.com", + websocket: true, + expectedError: `query parameter "port" is required`, + }, + "unable to parse port": { + url: "http://example.com?port=abc", + websocket: true, + expectedError: `unable to parse "abc" as a port: strconv.ParseUint: parsing "abc": invalid syntax`, + }, + "negative port": { + url: "http://example.com?port=-1", + websocket: true, + expectedError: `unable to parse "-1" as a port: strconv.ParseUint: parsing "-1": invalid syntax`, + }, + "one port": { + url: "http://example.com?port=80", + websocket: true, + expectedOpts: &V4Options{ + Ports: []int32{80}, + }, + }, + "multiple ports": { + url: "http://example.com?port=80,90,100", + websocket: true, + expectedOpts: &V4Options{ + Ports: []int32{80, 90, 100}, + }, + }, + "multiple port": { + url: "http://example.com?port=80&port=90", + websocket: true, + expectedOpts: &V4Options{ + Ports: []int32{80, 90}, + }, + }, + } + for name, test := range tests { + req, err := http.NewRequest(http.MethodGet, test.url, nil) + if err != nil { + t.Errorf("%s: invalid url %q err=%q", name, test.url, err) + continue + } + if test.websocket { + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + } + opts, err := NewV4Options(req) + if len(test.expectedError) > 0 { + if err == nil { + t.Errorf("%s: expected err=%q, but it was nil", name, test.expectedError) + } + if e, a := test.expectedError, err.Error(); e != a { + t.Errorf("%s: expected err=%q, got %q", name, e, a) + } + continue + } + if err != nil { + t.Errorf("%s: unexpected error %v", name, err) + continue + } + if !reflect.DeepEqual(test.expectedOpts, opts) { + t.Errorf("%s: expected options %#v, got %#v", name, test.expectedOpts, err) + } + } +} diff --git a/node/api/portforward.go b/node/api/portforward.go new file mode 100644 index 000000000..d25fcb3a8 --- /dev/null +++ b/node/api/portforward.go @@ -0,0 +1,116 @@ +// Copyright © 2017 The virtual-kubelet authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + "io" + "net/http" + "strings" + "time" + + "github.com/gorilla/mux" + "github.com/virtual-kubelet/virtual-kubelet/internal/kubernetes/portforward" + "k8s.io/apimachinery/pkg/types" +) + +// PortForwardHandlerFunc defines the handler function used to +// portforward, passing through the original dataStream +type PortForwardHandlerFunc func(ctx context.Context, namespace, pod string, port int32, stream io.ReadWriteCloser) error + +// PortForwardHandlerConfig is used to pass options to options to the container exec handler. +type PortForwardHandlerConfig struct { + // StreamIdleTimeout is the maximum time a streaming connection + // can be idle before the connection is automatically closed. + StreamIdleTimeout time.Duration + // StreamCreationTimeout is the maximum time for streaming connection + StreamCreationTimeout time.Duration +} + +// PortForwardHandlerOption configures a PortForwardHandlerConfig +// It is used as functional options passed to `HandlePortForward` +type PortForwardHandlerOption func(*PortForwardHandlerConfig) + +// WithPortForwardStreamIdleTimeout sets the idle timeout for a container port forward streaming +func WithPortForwardStreamIdleTimeout(dur time.Duration) PortForwardHandlerOption { + return func(cfg *PortForwardHandlerConfig) { + cfg.StreamIdleTimeout = dur + } +} + +// WithPortForwardCreationTimeout sets the creation timeout for a container exec stream +func WithPortForwardCreationTimeout(dur time.Duration) PortForwardHandlerOption { + return func(cfg *PortForwardHandlerConfig) { + cfg.StreamCreationTimeout = dur + } +} + +// HandlePortForward makes an http handler func from a Provider which forward ports to a container +// Note that this handler currently depends on gorrilla/mux to get url parts as variables. +func HandlePortForward(h PortForwardHandlerFunc, opts ...PortForwardHandlerOption) http.HandlerFunc { + if h == nil { + return NotImplemented + } + + var cfg PortForwardHandlerConfig + for _, o := range opts { + o(&cfg) + } + + if cfg.StreamIdleTimeout == 0 { + cfg.StreamIdleTimeout = 30 * time.Second + } + if cfg.StreamCreationTimeout == 0 { + cfg.StreamCreationTimeout = 30 * time.Second + } + + return handleError(func(w http.ResponseWriter, req *http.Request) error { + vars := mux.Vars(req) + + namespace := vars["namespace"] + + pod := vars["pod"] + + supportedStreamProtocols := strings.Split(req.Header.Get("X-Stream-Protocol-Version"), ",") + + portfwd := &portForwardContext{h: h, pod: pod, namespace: namespace} + portforward.ServePortForward( + w, + req, + portfwd, + pod, + "", + &portforward.V4Options{}, // This is only used for websocket connection + cfg.StreamIdleTimeout, + cfg.StreamCreationTimeout, + supportedStreamProtocols, + ) + + return nil + }) + +} + +type portForwardContext struct { + h PortForwardHandlerFunc + pod string + namespace string +} + +// PortForward Implements portforward.Portforwarder +// This is called by portforward.ServePortForward +func (p *portForwardContext) PortForward(ctx context.Context, name string, uid types.UID, port int32, stream io.ReadWriteCloser) error { + return p.h(ctx, p.namespace, p.pod, port, stream) +} diff --git a/node/api/server.go b/node/api/server.go index daa357bbe..c4ea98eec 100644 --- a/node/api/server.go +++ b/node/api/server.go @@ -36,6 +36,7 @@ type ServeMux interface { type PodHandlerConfig struct { //nolint:golint RunInContainer ContainerExecHandlerFunc AttachToContainer ContainerAttachHandlerFunc + PortForward PortForwardHandlerFunc GetContainerLogs ContainerLogsHandlerFunc // GetPods is meant to enumerate the pods that the provider knows about GetPods PodListerFunc @@ -58,7 +59,6 @@ func PodHandler(p PodHandlerConfig, debug bool) http.Handler { if debug { r.HandleFunc("/runningpods/", HandleRunningPods(p.GetPods)).Methods("GET") } - r.HandleFunc("/pods", HandleRunningPods(p.GetPodsFromKubernetes)).Methods("GET") r.HandleFunc("/containerLogs/{namespace}/{pod}/{container}", HandleContainerLogs(p.GetContainerLogs)).Methods("GET") r.HandleFunc( @@ -77,6 +77,14 @@ func PodHandler(p PodHandlerConfig, debug bool) http.Handler { WithExecStreamIdleTimeout(p.StreamIdleTimeout), ), ).Methods("POST", "GET") + r.HandleFunc( + "/portForward/{namespace}/{pod}", + HandlePortForward( + p.PortForward, + WithPortForwardStreamIdleTimeout(p.StreamCreationTimeout), + WithPortForwardCreationTimeout(p.StreamIdleTimeout), + ), + ).Methods("POST", "GET") if p.GetStatsSummary != nil { f := HandlePodStatsSummary(p.GetStatsSummary) diff --git a/node/nodeutil/provider.go b/node/nodeutil/provider.go index 73aab23a1..cbbc87ada 100644 --- a/node/nodeutil/provider.go +++ b/node/nodeutil/provider.go @@ -37,6 +37,9 @@ type Provider interface { // GetMetricsResource gets the metrics for the node, including running pods GetMetricsResource(context.Context) ([]*dto.MetricFamily, error) + + // PortForward forwards a local port to a port on the pod + PortForward(ctx context.Context, namespace, pod string, port int32, stream io.ReadWriteCloser) error } // ProviderConfig holds objects created by NewNodeFromClient that a provider may need to bootstrap itself. @@ -73,6 +76,7 @@ func AttachProviderRoutes(mux api.ServeMux) NodeOpt { GetMetricsResource: p.GetMetricsResource, StreamIdleTimeout: cfg.StreamIdleTimeout, StreamCreationTimeout: cfg.StreamCreationTimeout, + PortForward: p.PortForward, }, true)) } return nil