beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From al...@apache.org
Subject [beam] branch master updated: [BEAM-3306] Add first pass coder registry for Go
Date Tue, 05 Feb 2019 00:30:22 GMT
This is an automated email from the ASF dual-hosted git repository.

altay pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 2e34afa  [BEAM-3306] Add first pass coder registry for Go
     new 727b853  Merge pull request #7723 from lostluck/userdefinedcoder
2e34afa is described below

commit 2e34afa538b6840204ee2e0b991b21a723e2579c
Author: Robert Burke <robert@frantil.com>
AuthorDate: Mon Feb 4 20:41:27 2019 +0000

    [BEAM-3306] Add first pass coder registry for Go
---
 sdks/go/pkg/beam/beam.shims.go                     |   8 +-
 sdks/go/pkg/beam/coder.go                          |  56 +-
 sdks/go/pkg/beam/coder_test.go                     |   7 +-
 sdks/go/pkg/beam/core/funcx/signature.go           |   3 +
 sdks/go/pkg/beam/core/graph/coder/coder.go         |  35 +-
 sdks/go/pkg/beam/core/graph/coder/coder_test.go    | 564 +++++++++++++++++++++
 sdks/go/pkg/beam/core/graph/coder/registry.go      | 102 ++++
 sdks/go/pkg/beam/core/graph/coder/registry_test.go | 322 ++++++++++++
 sdks/go/pkg/beam/core/typex/class.go               |  65 ++-
 sdks/go/pkg/beam/core/typex/class_test.go          |  68 ++-
 sdks/go/pkg/beam/core/typex/fulltype.go            |   3 +-
 sdks/go/pkg/beam/core/typex/special.go             |   3 +-
 sdks/go/pkg/beam/forward.go                        |  42 ++
 sdks/go/pkg/beam/util.go                           |   2 +-
 14 files changed, 1213 insertions(+), 67 deletions(-)

diff --git a/sdks/go/pkg/beam/beam.shims.go b/sdks/go/pkg/beam/beam.shims.go
index de4c4ca..8c95289 100644
--- a/sdks/go/pkg/beam/beam.shims.go
+++ b/sdks/go/pkg/beam/beam.shims.go
@@ -30,15 +30,15 @@ import (
 )
 
 func init() {
-	runtime.RegisterFunction(JSONDec)
-	runtime.RegisterFunction(JSONEnc)
-	runtime.RegisterFunction(ProtoDec)
-	runtime.RegisterFunction(ProtoEnc)
 	runtime.RegisterFunction(addFixedKeyFn)
 	runtime.RegisterFunction(dropKeyFn)
 	runtime.RegisterFunction(dropValueFn)
 	runtime.RegisterFunction(explodeFn)
+	runtime.RegisterFunction(jsonDec)
+	runtime.RegisterFunction(jsonEnc)
 	runtime.RegisterFunction(makePartitionFn)
+	runtime.RegisterFunction(protoDec)
+	runtime.RegisterFunction(protoEnc)
 	runtime.RegisterFunction(swapKVFn)
 	runtime.RegisterType(reflect.TypeOf((*createFn)(nil)).Elem())
 	runtime.RegisterType(reflect.TypeOf((*reflect.Type)(nil)).Elem())
diff --git a/sdks/go/pkg/beam/coder.go b/sdks/go/pkg/beam/coder.go
index f4943bd..83b934f 100644
--- a/sdks/go/pkg/beam/coder.go
+++ b/sdks/go/pkg/beam/coder.go
@@ -27,6 +27,18 @@ import (
 	"github.com/golang/protobuf/proto"
 )
 
+// jsonCoder is implemented by types that provide both JSON marshalling and unmarshalling.
+type jsonCoder interface {
+	json.Marshaler
+	json.Unmarshaler
+}
+
+var protoMessageType = reflect.TypeOf((*proto.Message)(nil)).Elem()
+var jsonCoderType = reflect.TypeOf((*jsonCoder)(nil)).Elem()
+
+// Types implementing proto.Message are encoded with protoEnc/protoDec via the coder registry.
+func init() { coder.RegisterCoder(protoMessageType, protoEnc, protoDec) }
+
 // Coder defines how to encode and decode values of type 'A' into byte streams.
 // Coders are attached to PCollections of the same type. For PCollections
 // consumed by GBK, the attached coders are required to be deterministic.
@@ -81,8 +93,6 @@ func NewCoder(t FullType) Coder {
 	return Coder{c}
 }
 
-var protoMessageType = reflect.TypeOf((*proto.Message)(nil)).Elem()
-
 func inferCoder(t FullType) (*coder.Coder, error) {
 	switch t.Class() {
 	case typex.Concrete, typex.Container:
@@ -115,17 +125,17 @@ func inferCoder(t FullType) (*coder.Coder, error) {
 		case reflectx.ByteSlice:
 			return &coder.Coder{Kind: coder.Bytes, T: t}, nil
 		default:
-			// TODO(BEAM-3306): the coder registry should be consulted here for user
-			// specified types and their coders.
-			if t.Type().Implements(protoMessageType) {
-				c, err := newProtoCoder(t.Type())
-				if err != nil {
-					return nil, err
-				}
+			et := t.Type()
+			if c := coder.LookupCustomCoder(et); c != nil {
 				return &coder.Coder{Kind: coder.Custom, T: t, Custom: c}, nil
 			}
+			// Interface types that implement JSON marshalling can be handled by the default coder.
+			// otherwise, inference needs to fail here.
+			if et.Kind() == reflect.Interface && !et.Implements(jsonCoderType) {
+				return nil, fmt.Errorf("inferCoder failed: interface type %v has no coder registered", et)
+			}
 
-			c, err := newJSONCoder(t.Type())
+			c, err := newJSONCoder(et)
 			if err != nil {
 				return nil, err
 			}
@@ -174,14 +184,14 @@ func inferCoders(list []FullType) ([]*coder.Coder, error) {
 // form that doesn't require LengthPrefix'ing to cut up the bytestream from
 // the FnHarness.
 
-// ProtoEnc marshals the supplied proto.Message.
-func ProtoEnc(in T) ([]byte, error) {
+// protoEnc marshals in, which must hold a proto.Message (the unchecked type assertion panics otherwise).
+func protoEnc(in T) ([]byte, error) {
 	return proto.Marshal(in.(proto.Message))
 }
 
-// ProtoDec unmarshals the supplied bytes into an instance of the supplied
+// protoDec unmarshals the supplied bytes into an instance of the supplied
 // proto.Message type.
-func ProtoDec(t reflect.Type, in []byte) (T, error) {
+func protoDec(t reflect.Type, in []byte) (T, error) {
 	val := reflect.New(t.Elem()).Interface().(proto.Message)
 	if err := proto.Unmarshal(in, val); err != nil {
 		return nil, err
@@ -189,24 +199,16 @@ func ProtoDec(t reflect.Type, in []byte) (T, error) {
 	return val, nil
 }
 
-func newProtoCoder(t reflect.Type) (*coder.CustomCoder, error) {
-	c, err := coder.NewCustomCoder("proto", t, ProtoEnc, ProtoDec)
-	if err != nil {
-		return nil, fmt.Errorf("invalid coder: %v", err)
-	}
-	return c, nil
-}
-
 // Concrete and universal custom coders both have a similar signature.
 // Conversion is handled by reflection.
 
-// JSONEnc encodes the supplied value in JSON.
-func JSONEnc(in T) ([]byte, error) {
+// jsonEnc encodes the supplied value as JSON using json.Marshal.
+func jsonEnc(in T) ([]byte, error) {
 	return json.Marshal(in)
 }
 
-// JSONDec decodes the supplied JSON into an instance of the supplied type.
-func JSONDec(t reflect.Type, in []byte) (T, error) {
+// jsonDec decodes the supplied JSON into an instance of the supplied type.
+func jsonDec(t reflect.Type, in []byte) (T, error) {
 	val := reflect.New(t)
 	if err := json.Unmarshal(in, val.Interface()); err != nil {
 		return nil, err
@@ -215,7 +217,7 @@ func JSONDec(t reflect.Type, in []byte) (T, error) {
 }
 
 func newJSONCoder(t reflect.Type) (*coder.CustomCoder, error) {
-	c, err := coder.NewCustomCoder("json", t, JSONEnc, JSONDec)
+	c, err := coder.NewCustomCoder("json", t, jsonEnc, jsonDec)
 	if err != nil {
 		return nil, fmt.Errorf("invalid coder: %v", err)
 	}
diff --git a/sdks/go/pkg/beam/coder_test.go b/sdks/go/pkg/beam/coder_test.go
index b789ca8..6bd3849 100644
--- a/sdks/go/pkg/beam/coder_test.go
+++ b/sdks/go/pkg/beam/coder_test.go
@@ -13,12 +13,11 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-package beam_test
+package beam
 
 import (
 	"testing"
 
-	"github.com/apache/beam/sdks/go/pkg/beam"
 	"github.com/apache/beam/sdks/go/pkg/beam/core/util/reflectx"
 )
 
@@ -26,11 +25,11 @@ func TestJSONCoder(t *testing.T) {
 	tests := []int{43, 12431235, -2, 0, 1}
 
 	for _, test := range tests {
-		data, err := beam.JSONEnc(test)
+		data, err := jsonEnc(test)
 		if err != nil {
 			t.Fatalf("Failed to encode %v: %v", tests, err)
 		}
-		decoded, err := beam.JSONDec(reflectx.Int, data)
+		decoded, err := jsonDec(reflectx.Int, data)
 		if err != nil {
 			t.Fatalf("Failed to decode: %v", err)
 		}
diff --git a/sdks/go/pkg/beam/core/funcx/signature.go b/sdks/go/pkg/beam/core/funcx/signature.go
index e33df8f..fbb0ad6 100644
--- a/sdks/go/pkg/beam/core/funcx/signature.go
+++ b/sdks/go/pkg/beam/core/funcx/signature.go
@@ -175,6 +175,9 @@ func matchReq(list, models []reflect.Type) error {
 		}
 
 		model := models[i]
+		if t.Kind() == reflect.Interface && model.Implements(t) {
+			continue
+		}
 		if model != t {
 			return fmt.Errorf("type mismatch: %v, want %v", t, model)
 		}
diff --git a/sdks/go/pkg/beam/core/graph/coder/coder.go b/sdks/go/pkg/beam/core/graph/coder/coder.go
index b3e5dcb..7432c10 100644
--- a/sdks/go/pkg/beam/core/graph/coder/coder.go
+++ b/sdks/go/pkg/beam/core/graph/coder/coder.go
@@ -55,6 +55,12 @@ type CustomCoder struct {
 // Equals returns true iff the two custom coders are equal. It assumes that
 // functions with the same name and types are identical.
 func (c *CustomCoder) Equals(o *CustomCoder) bool {
+	if c == nil && o == nil {
+		return true
+	}
+	if c == nil && o != nil || c != nil && o == nil {
+		return false
+	}
 	if c.Name != o.Name {
 		return false
 	}
@@ -86,24 +92,43 @@ var (
 		OptReturn: []reflect.Type{reflectx.Error}}
 )
 
+// validateEncoder returns an error if encode does not satisfy encodeSig with typex.TType replaced by t.
+func validateEncoder(t reflect.Type, encode interface{}) error {
+	if err := funcx.Satisfy(encode, funcx.Replace(encodeSig, typex.TType, t)); err != nil {
+		return fmt.Errorf("validateEncoder: incorrect signature: %v", err)
+	}
+	// TODO(lostluck): 2019.02.03 - Determine if there are encode allocation bottlenecks.
+	return nil
+}
+
+// validateDecoder returns an error if decode does not satisfy decodeSig with typex.TType replaced by t.
+func validateDecoder(t reflect.Type, decode interface{}) error {
+	if err := funcx.Satisfy(decode, funcx.Replace(decodeSig, typex.TType, t)); err != nil {
+		return fmt.Errorf("validateDecoder: incorrect signature: %v", err)
+	}
+	// TODO(lostluck): 2019.02.03 - Expand cases to avoid []byte -> interface{} conversion
+	// in exec, & a beam Decoder interface.
+	return nil
+}
+
 // NewCustomCoder creates a coder for the supplied parameters defining a
 // particular encoding strategy.
 func NewCustomCoder(id string, t reflect.Type, encode, decode interface{}) (*CustomCoder, error) {
+	if err := validateEncoder(t, encode); err != nil {
+		return nil, fmt.Errorf("NewCustomCoder: %v", err)
+	}
 	enc, err := funcx.New(reflectx.MakeFunc(encode))
 	if err != nil {
 		return nil, fmt.Errorf("bad encode: %v", err)
 	}
-	if err := funcx.Satisfy(encode, funcx.Replace(encodeSig, typex.TType, t)); err != nil {
-		return nil, fmt.Errorf("encode has incorrect signature: %v", err)
+	if err := validateDecoder(t, decode); err != nil {
+		return nil, fmt.Errorf("NewCustomCoder: %v", err)
 	}
 
 	dec, err := funcx.New(reflectx.MakeFunc(decode))
 	if err != nil {
 		return nil, fmt.Errorf("bad decode: %v", err)
 	}
-	if err := funcx.Satisfy(decode, funcx.Replace(decodeSig, typex.TType, t)); err != nil {
-		return nil, fmt.Errorf("decode has incorrect signature: %v", err)
-	}
 
 	c := &CustomCoder{
 		Name: id,
diff --git a/sdks/go/pkg/beam/core/graph/coder/coder_test.go b/sdks/go/pkg/beam/core/graph/coder/coder_test.go
new file mode 100644
index 0000000..bfc9cee
--- /dev/null
+++ b/sdks/go/pkg/beam/core/graph/coder/coder_test.go
@@ -0,0 +1,564 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package coder
+
+import (
+	"fmt"
+	"reflect"
+	"testing"
+
+	"github.com/apache/beam/sdks/go/pkg/beam/core/util/reflectx"
+
+	"github.com/apache/beam/sdks/go/pkg/beam/core/typex"
+)
+
+type MyType struct{}
+
+func (MyType) A()  {} // value receiver: in the method sets of MyType and *MyType
+func (*MyType) B() {} // pointer receiver: only in the method set of *MyType
+
+type a interface {
+	A()
+}
+type b interface { // NOTE(review): same method set as a; B() may have been intended — confirm
+	A()
+}
+
+var (
+	mPtrT = reflect.TypeOf((*MyType)(nil))
+	mT    = mPtrT.Elem()
+)
+
+func TestValidEncoderForms(t *testing.T) {
+	tests := []struct {
+		t   reflect.Type
+		enc interface{}
+	}{
+		{t: mT, enc: func(MyType) []byte { return nil }},
+		{t: mT, enc: func(MyType) ([]byte, error) { return nil, nil }},
+		{t: mT, enc: func(reflect.Type, MyType) []byte { return nil }},
+		{t: mT, enc: func(reflect.Type, MyType) ([]byte, error) { return nil, nil }},
+		// Using a universal type as encode type.
+		{t: mT, enc: func(typex.T) []byte { return nil }},
+		{t: mT, enc: func(typex.T) ([]byte, error) { return nil, nil }},
+		{t: mT, enc: func(reflect.Type, typex.T) []byte { return nil }},
+		{t: mT, enc: func(reflect.Type, typex.T) ([]byte, error) { return nil, nil }},
+		// Using a satisfied interface type as encode type.
+		{t: mT, enc: func(a) []byte { return nil }},
+		{t: mT, enc: func(a) ([]byte, error) { return nil, nil }},
+		{t: mT, enc: func(reflect.Type, a) []byte { return nil }},
+		{t: mT, enc: func(reflect.Type, a) ([]byte, error) { return nil, nil }},
+		// The same interface forms must also validate against the pointer type *MyType.
+		{t: mPtrT, enc: func(a) []byte { return nil }},
+		{t: mPtrT, enc: func(a) ([]byte, error) { return nil, nil }},
+		{t: mPtrT, enc: func(reflect.Type, a) []byte { return nil }},
+		{t: mPtrT, enc: func(reflect.Type, a) ([]byte, error) { return nil, nil }},
+
+		{t: mPtrT, enc: func(b) []byte { return nil }},
+		{t: mPtrT, enc: func(b) ([]byte, error) { return nil, nil }},
+		{t: mPtrT, enc: func(reflect.Type, b) []byte { return nil }},
+		{t: mPtrT, enc: func(reflect.Type, b) ([]byte, error) { return nil, nil }},
+	}
+	for _, test := range tests {
+		test := test
+		t.Run(fmt.Sprintf("%v_%T", test.t, test.enc), func(t *testing.T) {
+			if err := validateEncoder(test.t, test.enc); err != nil { // validate against each case's own type
+				t.Fatal(err)
+			}
+		})
+	}
+}
+func TestValidDecoderForms(t *testing.T) {
+	tests := []struct {
+		t   reflect.Type
+		dec interface{}
+	}{
+		{t: mT, dec: func([]byte) MyType { return MyType{} }}, // Concrete decode type.
+		{t: mT, dec: func([]byte) (MyType, error) { return MyType{}, nil }},
+		{t: mT, dec: func(reflect.Type, []byte) MyType { return MyType{} }},
+		{t: mT, dec: func(reflect.Type, []byte) (MyType, error) { return MyType{}, nil }},
+
+		// Using a universal type as decode type.
+		{t: mT, dec: func([]byte) typex.T { return MyType{} }},
+		{t: mT, dec: func([]byte) (typex.T, error) { return MyType{}, nil }},
+		{t: mT, dec: func(reflect.Type, []byte) typex.T { return MyType{} }},
+		{t: mT, dec: func(reflect.Type, []byte) (typex.T, error) { return MyType{}, nil }},
+
+		// Using satisfied interfaces as decode type.
+		{t: mT, dec: func([]byte) a { return nil }},
+		{t: mT, dec: func([]byte) (a, error) { return nil, nil }},
+		{t: mT, dec: func(reflect.Type, []byte) a { return nil }},
+		{t: mT, dec: func(reflect.Type, []byte) (a, error) { return nil, nil }},
+
+		{t: mPtrT, dec: func([]byte) a { return nil }}, // Interface decode types checked against *MyType.
+		{t: mPtrT, dec: func([]byte) (a, error) { return nil, nil }},
+		{t: mPtrT, dec: func(reflect.Type, []byte) a { return nil }},
+		{t: mPtrT, dec: func(reflect.Type, []byte) (a, error) { return nil, nil }},
+
+		{t: mPtrT, dec: func([]byte) b { return nil }},
+		{t: mPtrT, dec: func([]byte) (b, error) { return nil, nil }},
+		{t: mPtrT, dec: func(reflect.Type, []byte) b { return nil }},
+		{t: mPtrT, dec: func(reflect.Type, []byte) (b, error) { return nil, nil }},
+	}
+	for _, test := range tests {
+		test := test // Capture the range variable for the subtest closure.
+		t.Run(fmt.Sprintf("%T", test.dec), func(t *testing.T) {
+			if err := validateDecoder(test.t, test.dec); err != nil {
+				t.Fatal(err)
+			}
+		})
+	}
+}
+
+func TestCoder_String(t *testing.T) {
+	ints := NewVarInt()
+	bytes := NewBytes()
+	global := NewGlobalWindow()
+	interval := NewIntervalWindow()
+	cusString, err := NewCustomCoder("customString", reflectx.String, func(string) []byte { return nil }, func([]byte) string { return "" })
+	if err != nil {
+		t.Fatal(err)
+	}
+	custom := &Coder{Kind: Custom, Custom: cusString, T: typex.New(reflectx.String)} // Expected to print as "<type>[<name>]".
+
+	tests := []struct {
+		want string
+		c    *Coder
+	}{{
+		want: "$",
+		c:    nil, // String must handle a nil *Coder.
+	}, {
+		want: "bytes",
+		c:    bytes,
+	}, {
+		want: "varint",
+		c:    ints,
+	}, {
+		want: "string[customString]",
+		c:    custom,
+	}, {
+		want: "W<varint>!GWC",
+		c:    NewW(ints, global),
+	}, {
+		want: "W<bytes>!IWC",
+		c:    NewW(bytes, interval),
+	}, {
+		want: "KV<bytes,varint>",
+		c:    NewKV([]*Coder{bytes, ints}),
+	}, {
+		want: "CoGBK<bytes,varint,bytes>",
+		c:    NewCoGBK([]*Coder{bytes, ints, bytes}),
+	}, {
+		want: "W<KV<bytes,varint>>!IWC",
+		c:    NewW(NewKV([]*Coder{bytes, ints}), interval),
+	}, {
+		want: "CoGBK<bytes,varint,string[customString]>",
+		c:    NewCoGBK([]*Coder{bytes, ints, custom}),
+	},
+	}
+	for _, test := range tests {
+		test := test
+		t.Run(test.want, func(t *testing.T) {
+			got := test.c.String()
+			if test.want != got {
+				t.Fatalf("c.String() = %v, want %v", got, test.want)
+			}
+		})
+	}
+}
+
+func TestCoder_Equals(t *testing.T) {
+	ints := NewVarInt()
+	bytes := NewBytes()
+	global := NewGlobalWindow()
+	interval := NewIntervalWindow()
+
+	cusStrEnc := func(string) []byte { return nil }
+	cusStrDec := func([]byte) string { return "" }
+
+	cusString1, err := NewCustomCoder("cus1", reflectx.String, cusStrEnc, cusStrDec)
+	if err != nil {
+		t.Fatal(err)
+	}
+	custom1 := &Coder{Kind: Custom, Custom: cusString1, T: typex.New(reflectx.String)}
+	cusString2, err := NewCustomCoder("cus2", reflectx.String, cusStrEnc, cusStrDec)
+	if err != nil {
+		t.Fatal(err)
+	}
+	custom2 := &Coder{Kind: Custom, Custom: cusString2, T: typex.New(reflectx.String)}
+	cusSameAs1, err := NewCustomCoder("cus1", reflectx.String, cusStrEnc, cusStrDec) // Same name, type and functions as cusString1.
+	if err != nil {
+		t.Fatal(err)
+	}
+	customSame := &Coder{Kind: Custom, Custom: cusSameAs1, T: typex.New(reflectx.String)}
+
+	tests := []struct {
+		want bool
+		a, b *Coder
+	}{{
+		want: true,
+		a:    bytes,
+		b:    bytes,
+	}, {
+		want: true,
+		a:    ints,
+		b:    ints,
+	}, {
+		want: false,
+		a:    ints,
+		b:    bytes,
+	}, {
+		want: true,
+		a:    custom1,
+		b:    custom1,
+	}, {
+		want: false,
+		a:    custom1,
+		b:    custom2,
+	}, {
+		want: true,
+		a:    custom1,
+		b:    customSame,
+	}, {
+		want: true,
+		a:    NewW(ints, global),
+		b:    NewW(ints, global),
+	}, {
+		want: false,
+		a:    NewW(ints, global),
+		b:    NewW(ints, interval),
+	}, {
+		want: false,
+		a:    NewW(bytes, global),
+		b:    NewW(ints, global),
+	}, {
+		want: true,
+		a:    NewW(custom1, interval),
+		b:    NewW(customSame, interval),
+	}, {
+		want: true,
+		a:    NewKV([]*Coder{custom1, ints}),
+		b:    NewKV([]*Coder{customSame, ints}),
+	}, {
+		want: true,
+		a:    NewCoGBK([]*Coder{custom1, ints, customSame}),
+		b:    NewCoGBK([]*Coder{customSame, ints, custom1}),
+	}, {
+		want: false,
+		a:    NewCoGBK([]*Coder{custom1, ints, customSame}),
+		b:    NewCoGBK([]*Coder{customSame, ints, custom2}),
+	},
+	}
+	for _, test := range tests {
+		test := test
+		t.Run(fmt.Sprintf("%v_vs_%v", test.a, test.b), func(t *testing.T) {
+			if got := test.a.Equals(test.b); test.want != got { // Equals is checked in both directions.
+				t.Errorf("A vs B: %v.Equals(%v) = %v, want %v", test.a, test.b, got, test.want)
+			}
+			if got := test.b.Equals(test.a); test.want != got {
+				t.Errorf("B vs A: %v.Equals(%v) = %v, want %v", test.b, test.a, got, test.want)
+			}
+		})
+	}
+}
+
+func TestCustomCoder_Equals(t *testing.T) {
+	cusStrEnc := func(string) []byte { return nil }
+	cusStrDec := func([]byte) string { return "" }
+
+	cusStrEnc2 := func(string) []byte { return nil } // Distinct function values with identical signatures.
+	cusStrDec2 := func([]byte) string { return "" }
+
+	newCC := func(t *testing.T, name string, et reflect.Type, enc, dec interface{}) *CustomCoder {
+		t.Helper()
+		cc, err := NewCustomCoder(name, et, enc, dec)
+		if err != nil {
+			t.Fatalf("couldn't get custom coder for %v: %v", et, err)
+		}
+		return cc
+	}
+
+	cusBase := newCC(t, "cus1", reflectx.String, cusStrEnc, cusStrDec)
+	cusSame := newCC(t, "cus1", reflectx.String, cusStrEnc, cusStrDec)
+	cusDiffName := newCC(t, "cus2", reflectx.String, cusStrEnc, cusStrDec)
+	cusDiffType := newCC(t, "cus1", reflectx.Int, func(int) []byte { return nil }, func([]byte) int { return 0 })
+	cusDiffEnc := newCC(t, "cus1", reflectx.String, cusStrEnc2, cusStrDec)
+	cusDiffDec := newCC(t, "cus1", reflectx.String, cusStrEnc, cusStrDec2)
+
+	tests := []struct {
+		name string
+		want bool
+		a, b *CustomCoder
+	}{{
+		name: "nils",
+		want: true,
+		a:    nil,
+		b:    nil,
+	}, {
+		name: "baseVsNil",
+		want: false,
+		a:    cusBase,
+		b:    nil,
+	}, {
+		name: "sameInstance",
+		want: true,
+		a:    cusBase,
+		b:    cusBase,
+	}, {
+		name: "identical",
+		want: true,
+		a:    cusBase,
+		b:    cusSame,
+	}, {
+		name: "diffType",
+		want: false,
+		a:    cusBase,
+		b:    cusDiffType,
+	}, {
+		name: "diffEnc",
+		want: false,
+		a:    cusBase,
+		b:    cusDiffEnc,
+	}, {
+		name: "diffDec",
+		want: false,
+		a:    cusBase,
+		b:    cusDiffDec,
+	}, {
+		name: "diffName",
+		want: false,
+		a:    cusBase,
+		b:    cusDiffName,
+	},
+	}
+	for _, test := range tests {
+		test := test
+		t.Run(test.name, func(t *testing.T) {
+			if got := test.a.Equals(test.b); test.want != got { // Equals must be symmetric: check both directions.
+				t.Errorf("A vs B: %v.Equals(%v) = %v, want %v", test.a, test.b, got, test.want)
+			}
+			if got := test.b.Equals(test.a); test.want != got {
+				t.Errorf("B vs A: %v.Equals(%v) = %v, want %v", test.b, test.a, got, test.want)
+			}
+		})
+	}
+}
+
+func TestSkipW(t *testing.T) {
+	want := NewBytes()
+	t.Run("unwindowed", func(t *testing.T) { // A non-windowed coder should come back Equal to the input.
+		if got := SkipW(want); !want.Equals(got) {
+			t.Fatalf("SkipW(%v) = %v, want %v", want, got, want)
+		}
+	})
+	t.Run("windowed", func(t *testing.T) { // The window wrapper should be removed, leaving the component coder.
+		in := NewW(want, NewGlobalWindow())
+		if got := SkipW(in); !want.Equals(got) {
+			t.Fatalf("SkipW(%v) = %v, want %v", in, got, want)
+		}
+	})
+}
+
+func TestNewW(t *testing.T) {
+	global := NewGlobalWindow()
+	bytes := NewBytes()
+
+	tests := []struct {
+		name        string
+		c           *Coder
+		w           *WindowCoder
+		shouldpanic bool
+		want        *Coder
+	}{{
+		name: "nil_nil",
+		c:    nil, w: nil, shouldpanic: true,
+	}, {
+		name: "nil_global",
+		c:    nil, w: global, shouldpanic: true,
+	}, {
+		name: "bytes_nil",
+		c:    bytes, w: nil, shouldpanic: true,
+	}, {
+		name: "bytes_global",
+		c:    bytes, w: global,
+		want: &Coder{Kind: WindowedValue, T: typex.NewW(bytes.T), Window: global, Components: []*Coder{bytes}},
+	},
+	}
+
+	for _, test := range tests {
+		test := test
+		t.Run(test.name, func(t *testing.T) {
+			if test.shouldpanic {
+				defer func() { // Fail the subtest unless NewW below panicked.
+					if p := recover(); p != nil {
+						t.Log(p)
+						return
+					}
+					t.Fatalf("NewW(%v, %v): want panic", test.c, test.w)
+				}()
+			}
+			got := NewW(test.c, test.w)
+			if !IsW(got) {
+				t.Errorf("IsW(%v) = false, want true", got)
+			}
+			if test.want != nil && !test.want.Equals(got) {
+				t.Fatalf("NewW(%v, %v) = %v, want %v", test.c, test.w, got, test.want)
+			}
+		})
+	}
+}
+
+func TestNewKV(t *testing.T) {
+	bytes := NewBytes()
+	ints := NewVarInt()
+
+	tests := []struct {
+		name        string
+		cs          []*Coder
+		shouldpanic bool
+		want        *Coder
+	}{{
+		name:        "nil",
+		cs:          nil,
+		shouldpanic: true,
+	}, {
+		name:        "empty",
+		cs:          []*Coder{},
+		shouldpanic: true,
+	}, {
+		name:        "bytes",
+		cs:          []*Coder{bytes},
+		shouldpanic: true,
+	}, {
+		name:        "bytes_nil",
+		cs:          []*Coder{bytes, nil},
+		shouldpanic: true,
+	}, {
+		name:        "nil_ints",
+		cs:          []*Coder{nil, ints},
+		shouldpanic: true,
+	}, {
+		name: "bytes_ints",
+		cs:   []*Coder{bytes, ints},
+		want: &Coder{Kind: KV, T: typex.NewKV(bytes.T, ints.T), Components: []*Coder{bytes, ints}},
+	}, {
+		name: "ints_bytes",
+		cs:   []*Coder{ints, bytes},
+		want: &Coder{Kind: KV, T: typex.NewKV(ints.T, bytes.T), Components: []*Coder{ints, bytes}},
+	}, {
+		name:        "ints_bytes_bytes", // These cases expect exactly two non-nil components.
+		cs:          []*Coder{ints, bytes, bytes},
+		shouldpanic: true,
+	}, {
+		name:        "ints_nil_bytes",
+		cs:          []*Coder{ints, nil, bytes},
+		shouldpanic: true,
+	},
+	}
+
+	for _, test := range tests {
+		test := test
+		t.Run(test.name, func(t *testing.T) {
+			if test.shouldpanic {
+				defer func() { // Fail the subtest unless NewKV below panicked.
+					if p := recover(); p != nil {
+						t.Log(p)
+						return
+					}
+					t.Fatalf("NewKV(%v): want panic", test.cs)
+				}()
+			}
+			got := NewKV(test.cs)
+			if !IsKV(got) {
+				t.Errorf("IsKV(%v) = false, want true", got)
+			}
+			if test.want != nil && !test.want.Equals(got) {
+				t.Fatalf("NewKV(%v) = %v, want %v", test.cs, got, test.want)
+			}
+		})
+	}
+}
+
+func TestNewCoGBK(t *testing.T) {
+	bytes := NewBytes()
+	ints := NewVarInt()
+
+	tests := []struct {
+		name        string
+		cs          []*Coder
+		shouldpanic bool
+		want        *Coder
+	}{{
+		name:        "nil",
+		cs:          nil,
+		shouldpanic: true,
+	}, {
+		name:        "empty",
+		cs:          []*Coder{},
+		shouldpanic: true,
+	}, {
+		name:        "bytes", // These cases expect at least two non-nil components.
+		cs:          []*Coder{bytes},
+		shouldpanic: true,
+	}, {
+		name: "bytes_ints",
+		cs:   []*Coder{bytes, ints},
+		want: &Coder{Kind: CoGBK, T: typex.NewCoGBK(bytes.T, ints.T), Components: []*Coder{bytes, ints}},
+	}, {
+		name: "ints_bytes",
+		cs:   []*Coder{ints, bytes},
+		want: &Coder{Kind: CoGBK, T: typex.NewCoGBK(ints.T, bytes.T), Components: []*Coder{ints, bytes}},
+	}, {
+		name:        "ints_nil",
+		cs:          []*Coder{ints, nil},
+		shouldpanic: true,
+	}, {
+		name:        "nil_ints",
+		cs:          []*Coder{nil, ints},
+		shouldpanic: true,
+	}, {
+		name: "ints_bytes_bytes_ints_bytes",
+		cs:   []*Coder{ints, bytes, bytes, ints, bytes},
+		want: &Coder{Kind: CoGBK, T: typex.NewCoGBK(ints.T, bytes.T, bytes.T, ints.T, bytes.T), Components: []*Coder{ints, bytes, bytes, ints, bytes}},
+	}, {
+		name:        "ints_bytes_bytes_ints_nil_bytes",
+		cs:          []*Coder{ints, bytes, bytes, ints, nil, bytes},
+		shouldpanic: true,
+	},
+	}
+
+	for _, test := range tests {
+		test := test
+		t.Run(test.name, func(t *testing.T) {
+			if test.shouldpanic {
+				defer func() {
+					if p := recover(); p != nil {
+						t.Log(p)
+						return
+					}
+					t.Fatalf("NewCoGBK(%v): want panic", test.cs)
+				}()
+			}
+			got := NewCoGBK(test.cs)
+			if !IsCoGBK(got) {
+				t.Errorf("IsCoGBK(%v) = false, want true", got)
+			}
+			if test.want != nil && !test.want.Equals(got) {
+				t.Fatalf("NewCoGBK(%v) = %v, want %v", test.cs, got, test.want)
+			}
+		})
+	}
+}
diff --git a/sdks/go/pkg/beam/core/graph/coder/registry.go b/sdks/go/pkg/beam/core/graph/coder/registry.go
new file mode 100644
index 0000000..d95258d
--- /dev/null
+++ b/sdks/go/pkg/beam/core/graph/coder/registry.go
@@ -0,0 +1,102 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package coder
+
+import (
+	"fmt"
+	"reflect"
+)
+
+var (
+	coderRegistry     = make(map[uintptr]func(reflect.Type) *CustomCoder) // tkey(registered type) -> builder of a CustomCoder for a concrete type
+	interfaceOrdering []reflect.Type                                      // registered interface types, oldest registration first
+)
+
+// RegisterCoder registers a user defined coder for a given type, and will
+// be used if there is no beam coder for that type. Must be called prior to beam.Init(),
+// preferably in an init() function.
+//
+// Coders are encoder and decoder pairs, and operate around []bytes.
+//
+// The coder used for a given type follows this ordering:
+//   1. Coders for Known Beam types.
+//   2. Coders registered for specific types
+//   3. Coders registered for interfaces types
+//   4. Default coder (JSON)
+//
+// Types of kind Interface, are handled specially by the registry, so they may be iterated
+// over to check if element types implement them.
+//
+// Repeated registrations of the same type overrides prior ones.
+func RegisterCoder(t reflect.Type, enc, dec interface{}) {
+	key := tkey(t)
+
+	if _, err := NewCustomCoder(t.String(), t, enc, dec); err != nil {
+		panic(fmt.Sprintf("RegisterCoder failed for type %v: %v", t, err))
+	}
+
+	if t.Kind() == reflect.Interface {
+		// If it's already in the registry, then it's already in the list
+		// and should be removed.
+		if _, ok := coderRegistry[key]; ok {
+			var index int
+			for i, iT := range interfaceOrdering {
+				iKey := tkey(iT)
+				if iKey == key {
+					index = i
+					break
+				}
+			}
+			interfaceOrdering = append(interfaceOrdering[:index], interfaceOrdering[index+1:]...)
+		}
+		// Either way, always append.
+		interfaceOrdering = append(interfaceOrdering, t)
+	}
+	name := t.String() // Use the real type names for coders.
+	coderRegistry[key] = func(rt reflect.Type) *CustomCoder {
+		// We need to provide the concrete type, so that coders that use
+		// the reflect.Type have the proper instance.
+		cc, err := NewCustomCoder(name, rt, enc, dec)
+		if err != nil {
+			// An error on look up shouldn't happen after the validation.
+			panic(fmt.Sprintf("Creating %v CustomCoder for type %v failed: %v", name, rt, err))
+		}
+		return cc
+	}
+}
+
+// LookupCustomCoder returns the custom coder registered for t, or nil if none
+// applies. A registration for t itself takes precedence; otherwise registered
+// interface types are searched from most to least recently registered, and the
+// first one that t implements supplies the coder.
+func LookupCustomCoder(t reflect.Type) *CustomCoder {
+	if makeCoder, found := coderRegistry[tkey(t)]; found {
+		return makeCoder(t)
+	}
+	for n := len(interfaceOrdering); n > 0; n-- {
+		iface := interfaceOrdering[n-1]
+		if !t.Implements(iface) {
+			continue
+		}
+		return coderRegistry[tkey(iface)](t)
+	}
+	return nil
+}
+
+// tkey returns reflect.ValueOf(t).Pointer(), the pointer underlying the reflect.Type value, for use as a registry map key.
+func tkey(t reflect.Type) uintptr {
+	return reflect.ValueOf(t).Pointer() // NOTE(review): reflect.Type values are comparable, so a map keyed by reflect.Type could be used directly.
+}
diff --git a/sdks/go/pkg/beam/core/graph/coder/registry_test.go b/sdks/go/pkg/beam/core/graph/coder/registry_test.go
new file mode 100644
index 0000000..468be6e
--- /dev/null
+++ b/sdks/go/pkg/beam/core/graph/coder/registry_test.go
@@ -0,0 +1,322 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package coder
+
+import (
+	"reflect"
+	"testing"
+
+	"github.com/apache/beam/sdks/go/pkg/beam/core/typex"
+)
+
+// clearRegistry empties the coder registry and interface ordering so a subtest starts clean.
+func clearRegistry() {
+	coderRegistry, interfaceOrdering = map[uintptr]func(reflect.Type) *CustomCoder{}, []reflect.Type{}
+}
+
+type MyInt int
+
+func (MyInt) SatisfiesA() {}
+func (MyInt) SatisfiesB() {}
+
+type MyStruct struct{}
+
+func (*MyStruct) SatisfiesB() {} // Pointer receiver: only *MyStruct implements myB.
+func (MyStruct) SatisfiesC()  {} // Value receiver: MyStruct and *MyStruct implement myC.
+
+type myA interface {
+	SatisfiesA()
+}
+type myB interface {
+	SatisfiesB()
+}
+type myC interface {
+	SatisfiesC()
+}
+
+var (
+	miType    = reflect.TypeOf((*MyInt)(nil)).Elem() // Implements myA & myB
+	msPtrType = reflect.TypeOf((*MyStruct)(nil))     // Implements myB & myC
+	msType    = msPtrType.Elem()                     // Implements myC only
+	aType     = reflect.TypeOf((*myA)(nil)).Elem()
+	bType     = reflect.TypeOf((*myB)(nil)).Elem()
+	cType     = reflect.TypeOf((*myC)(nil)).Elem()
+)
+
+// Stub encoder/decoder functions for the above types; they only return zero values.
+
+func miEnc(MyInt) []byte { return nil }
+func miDec([]byte) MyInt { return 0 }
+
+func msEnc(MyStruct) []byte     { return nil }
+func msDec([]byte) MyStruct     { return MyStruct{} }
+func msPtrEnc(*MyStruct) []byte { return nil }
+func msPtrDec([]byte) *MyStruct { return nil }
+
+func aEnc(myA) []byte { return nil }
+func aDec([]byte) myA { return nil }
+
+func bEnc(myB) []byte { return nil }
+func bDec([]byte) myB { return nil }
+
+func cEnc(myC) []byte { return nil }
+func cDec([]byte) myC { return nil }
+
+// Stub coder functions for the general typex.T type.
+func tEnc(typex.T) []byte { return nil }
+func tDec([]byte) typex.T { return nil }
+
+// TestRegisterCoder checks that RegisterCoder panics when validation of the
+// type, encoder, and decoder fails. Every case below is intentionally invalid.
+func TestRegisterCoder(t *testing.T) {
+	tests := []struct {
+		name     string
+		typ      reflect.Type
+		enc, dec interface{}
+	}{{
+		name: "nonSatisfyingInterface",
+		typ:  msType, enc: aEnc, dec: aDec,
+	}, {
+		name: "swapped",
+		typ:  msType, enc: msDec, dec: msEnc,
+	}, {
+		name: "mismatchedDec",
+		typ:  msType, enc: msEnc, dec: miDec,
+	}, {
+		name: "mismatchedEnc",
+		typ:  msType, enc: miEnc, dec: msDec,
+	}, {
+		// TODO(BEAM-6578)- Consider using the pointer coder if it exists.
+		name: "pointerType",
+		typ:  msType, enc: msPtrEnc, dec: msPtrDec,
+	}, {
+		name: "valueType",
+		typ:  msPtrType, enc: msEnc, dec: msDec,
+	}, {
+		name: "badEnc-inappropriateFn",
+		typ:  msType, enc: func() {}, dec: msDec,
+	}, {
+		name: "badDec-inappropriateFn",
+		typ:  msType, enc: msEnc, dec: func() {},
+	}, {
+		name: "badEnc-nil",
+		typ:  msType, enc: nil, dec: msDec,
+	}, {
+		name: "badDec-nil",
+		typ:  msType, enc: msEnc, dec: nil,
+	}, {
+		name: "badEnc-int",
+		typ:  msType, enc: 41, dec: msDec,
+	}, {
+		name: "badDec-int",
+		typ:  msType, enc: msEnc, dec: 43,
+	},
+	}
+
+	for _, test := range tests {
+		test := test
+		t.Run(test.name, func(t *testing.T) {
+			defer func() {
+				if p := recover(); p != nil {
+					t.Log(p)
+					return
+				}
+				t.Fatalf("RegisterCoder(%v, %T, %T): want panic", test.typ, test.enc, test.dec)
+			}()
+			RegisterCoder(test.typ, test.enc, test.dec)
+		})
+	}
+}
+
+// TestLookupCustomCoder checks which registered coder, if any,
+// LookupCustomCoder selects for a type under various registration orders.
+func TestLookupCustomCoder(t *testing.T) {
+	newCC := func(t *testing.T, name string, et reflect.Type, enc, dec interface{}) *CustomCoder {
+		t.Helper()
+		cc, err := NewCustomCoder(name, et, enc, dec)
+		if err != nil {
+			// Fail fast: a nil expectation would be indistinguishable from "no coder".
+			t.Fatalf("couldn't get custom coder for %v: %v", et, err)
+		}
+		return cc
+	}
+
+	miCC := newCC(t, miType.String(), miType, miEnc, miDec)
+	msCC := newCC(t, msType.String(), msType, msEnc, msDec)
+	msPtrCC := newCC(t, msPtrType.String(), msPtrType, msPtrEnc, msPtrDec)
+
+	tests := []struct {
+		name   string
+		config func(t *testing.T)
+		lookup reflect.Type
+		want   *CustomCoder // nil means no custom coder should be found.
+	}{{
+		name:   "not_registered_specific",
+		lookup: miType,
+		config: func(t *testing.T) {},
+	}, {
+		name:   "not_registered_interface",
+		lookup: cType,
+		config: func(t *testing.T) {},
+	}, {
+		name:   "specific",
+		lookup: miType,
+		want:   miCC,
+		config: func(t *testing.T) {
+			RegisterCoder(msType, msEnc, msDec)
+			RegisterCoder(miType, miEnc, miDec)
+		},
+	}, {
+		name:   "interfaceDirect",
+		lookup: aType,
+		want:   newCC(t, aType.String(), aType, aEnc, aDec),
+		config: func(t *testing.T) {
+			RegisterCoder(aType, aEnc, aDec)
+		},
+	}, {
+		name:   "interfaceImplements",
+		lookup: miType,
+		want:   newCC(t, aType.String(), miType, aEnc, aDec),
+		config: func(t *testing.T) {
+			RegisterCoder(aType, aEnc, aDec)
+		},
+	}, {
+		name:   "interfaceNotImplemented",
+		lookup: miType,
+		config: func(t *testing.T) {
+			RegisterCoder(cType, cEnc, cDec)
+		},
+	}, {
+		name:   "specificPrecedence",
+		lookup: miType,
+		want:   miCC,
+		config: func(t *testing.T) {
+			RegisterCoder(miType, miEnc, miDec)
+			RegisterCoder(aType, aEnc, aDec)
+		},
+	}, {
+		name:   "interfacePrecedence",
+		lookup: miType,
+		want:   newCC(t, bType.String(), miType, bEnc, bDec),
+		config: func(t *testing.T) {
+			RegisterCoder(aType, aEnc, aDec)
+			RegisterCoder(bType, bEnc, bDec)
+		},
+	}, {
+		name:   "interfacePrecedence2",
+		lookup: miType,
+		want:   newCC(t, bType.String(), miType, bEnc, bDec),
+		config: func(t *testing.T) {
+			RegisterCoder(aType, aEnc, aDec)
+			RegisterCoder(bType, bEnc, bDec)
+			RegisterCoder(cType, cEnc, cDec)
+			if got, want := len(interfaceOrdering), 3; got != want {
+				t.Fatalf("len(interfaceOrdering) = %v, want %v; contents: %v", got, want, interfaceOrdering)
+			}
+		},
+	}, {
+		name:   "interfacePrecedenceOverriding",
+		lookup: miType,
+		want:   newCC(t, aType.String(), miType, aEnc, aDec),
+		config: func(t *testing.T) {
+			RegisterCoder(aType, aEnc, aDec)
+			RegisterCoder(bType, bEnc, bDec)
+			// Re-registering myA moves it to the end of the interface ordering.
+			RegisterCoder(aType, aEnc, aDec)
+			RegisterCoder(cType, cEnc, cDec)
+			if got, want := len(interfaceOrdering), 3; got != want {
+				t.Fatalf("len(interfaceOrdering) = %v, want %v; contents: %v", got, want, interfaceOrdering)
+			}
+		},
+	}, {
+		name:   "interfacePointerRecv",
+		lookup: msPtrType,
+		want:   newCC(t, bType.String(), msPtrType, bEnc, bDec),
+		config: func(t *testing.T) {
+			RegisterCoder(cType, cEnc, cDec)
+			RegisterCoder(bType, bEnc, bDec)
+		},
+	}, {
+		name:   "interfacePointerRecv_A",
+		lookup: msPtrType,
+		// A pointer type's method set includes its value-receiver methods,
+		// so *MyStruct also implements myC. Because myC is registered after
+		// myB, the myC coder wins, even though *MyStruct satisfies it only
+		// through MyStruct's value-receiver method.
+		want: newCC(t, cType.String(), msPtrType, cEnc, cDec),
+		config: func(t *testing.T) {
+			RegisterCoder(bType, bEnc, bDec)
+			RegisterCoder(cType, cEnc, cDec)
+		},
+	}, {
+		name:   "interfaceValueRecv",
+		lookup: msType,
+		want:   newCC(t, cType.String(), msType, cEnc, cDec),
+		config: func(t *testing.T) {
+			RegisterCoder(cType, cEnc, cDec) // Only one that satisfies.
+			RegisterCoder(bType, bEnc, bDec)
+		},
+	}, {
+		name:   "pointersPtrOverValue",
+		lookup: msPtrType,
+		want:   msPtrCC,
+		config: func(t *testing.T) {
+			RegisterCoder(msType, msEnc, msDec)
+			RegisterCoder(msPtrType, msPtrEnc, msPtrDec) // Should pick this one.
+			// Conflating interfaces.
+			RegisterCoder(bType, bEnc, bDec)
+			RegisterCoder(cType, cEnc, cDec)
+		},
+	}, {
+		name:   "pointersValueOverPointer",
+		lookup: msType,
+		want:   msCC,
+		config: func(t *testing.T) {
+			RegisterCoder(msType, msEnc, msDec) // Should pick this one.
+			RegisterCoder(msPtrType, msPtrEnc, msPtrDec)
+			// Conflating interfaces.
+			RegisterCoder(bType, bEnc, bDec)
+			RegisterCoder(cType, cEnc, cDec)
+		},
+	}, {
+		// TODO(BEAM-6578)- Consider using the pointer coder if it exists.
+		name:   "pointersNoDereferencingType",
+		lookup: msType,
+		want:   nil,
+		config: func(t *testing.T) {
+			RegisterCoder(msPtrType, msPtrEnc, msPtrDec)
+		},
+	}, {
+		name:   "specificNoIndirectingType",
+		lookup: msPtrType,
+		want:   nil,
+		config: func(t *testing.T) {
+			// Will never be relaxed, as it would lead to excess copying.
+			RegisterCoder(msType, msEnc, msDec)
+		},
+	},
+	}
+	for _, test := range tests {
+		test := test
+		t.Run(test.name, func(t *testing.T) {
+			// Start every case from an empty registry.
+			clearRegistry()
+			test.config(t)
+			got := LookupCustomCoder(test.lookup)
+			if !test.want.Equals(got) {
+				t.Errorf("LookupCustomCoder(%v) = %v, want %v", test.lookup, got, test.want)
+			}
+		})
+	}
+}
diff --git a/sdks/go/pkg/beam/core/typex/class.go b/sdks/go/pkg/beam/core/typex/class.go
index 9f041d0..cb34ad7 100644
--- a/sdks/go/pkg/beam/core/typex/class.go
+++ b/sdks/go/pkg/beam/core/typex/class.go
@@ -21,7 +21,7 @@ import (
 	"unicode"
 	"unicode/utf8"
 
-	"github.com/golang/protobuf/proto"
+	"github.com/apache/beam/sdks/go/pkg/beam/core/util/reflectx"
 )
 
 // Class is the type "class" of data as distinguished by the runtime. The class
@@ -47,7 +47,22 @@ const (
 	Composite
 )
 
-var protoMessageType = reflect.TypeOf((*proto.Message)(nil)).Elem()
+// String returns the name of c, or "Class(N)" for an unrecognized value N.
+func (c Class) String() string {
+	switch c {
+	case Invalid:
+		return "Invalid"
+	case Concrete:
+		return "Concrete"
+	case Universal:
+		return "Universal"
+	case Container:
+		return "Container"
+	case Composite:
+		return "Composite"
+	}
+	return fmt.Sprintf("Class(%d)", int(c)) // Never panic: String is often called while logging bad values.
+}
 
 // TODO(herohde) 5/16/2017: maybe we should add more classes, so that every
 // reasonable type (such as error) is not Invalid, even though it is not
@@ -76,31 +91,41 @@ func ClassOf(t reflect.Type) Class {
 // data must be fully serializable. Functions and channels are examples of invalid
 // types. Aggregate types with no universals are considered concrete here.
 func IsConcrete(t reflect.Type) bool {
-	if t == nil || t == EventTimeType || t.Implements(WindowType) {
-		return false
-	}
+	return t != nil && isConcrete(t, make(map[uintptr]bool)) // nil would panic in isConcrete's Value.Pointer call.
+}
 
-	// TODO(BEAM-3306): the coder registry should be consulted here for user
-	// specified types and their coders.
-	if t.Implements(protoMessageType) {
+func isConcrete(t reflect.Type, visited map[uintptr]bool) bool {
+	// Check that we haven't hit a recursive loop.
+	key := reflect.ValueOf(t).Pointer()
+	// If there's an invalid field in a recursive type
+	// then the layer above will find it.
+	if visited[key] {
 		return true
 	}
+	visited[key] = true
+
+	// Handle special types.
+	if t == nil ||
+		t == EventTimeType ||
+		t.Implements(WindowType) ||
+		t == reflectx.Error ||
+		t == reflectx.Context ||
+		IsUniversal(t) {
+		return false
+	}
 
 	switch t.Kind() {
-	case reflect.Invalid, reflect.UnsafePointer, reflect.Uintptr, reflect.Interface:
+	case reflect.Invalid, reflect.UnsafePointer, reflect.Uintptr:
 		return false // no unmanageable types
 
 	case reflect.Chan, reflect.Func:
 		return false // no unserializable types
 
-	case reflect.Map, reflect.Array:
-		return false // TBD
-
-	case reflect.Slice:
-		return IsConcrete(t.Elem())
+	case reflect.Map:
+		return isConcrete(t.Elem(), visited) && isConcrete(t.Key(), visited)
 
-	case reflect.Ptr:
-		return false // TBD
+	case reflect.Array, reflect.Slice, reflect.Ptr:
+		return isConcrete(t.Elem(), visited)
 
 	case reflect.Struct:
 		for i := 0; i < t.NumField(); i++ {
@@ -112,13 +137,17 @@ func IsConcrete(t reflect.Type) bool {
 			f := t.Field(i)
 			if len(f.Name) > 0 {
 				r, _ := utf8.DecodeRuneInString(f.Name)
-				if unicode.IsUpper(r) && !IsConcrete(f.Type) {
+				if unicode.IsUpper(r) && !isConcrete(f.Type, visited) {
 					return false
 				}
 			}
 		}
 		return true
 
+	case reflect.Interface:
+		// Interface types must fail at construction time if no coder is registered for them.
+		return true
+
 	case reflect.Bool:
 		return true
 
@@ -145,6 +174,8 @@ func IsConcrete(t reflect.Type) bool {
 // IsContainer returns true iff the given type is an container data type,
 // such as []int or []T.
 func IsContainer(t reflect.Type) bool {
+	// TODO(lostluck) 2019.02.03: Should we consider maps a container for
+	// beam specific purposes?
 	switch {
 	case IsList(t):
 		if IsUniversal(t.Elem()) || IsConcrete(t.Elem()) {
diff --git a/sdks/go/pkg/beam/core/typex/class_test.go b/sdks/go/pkg/beam/core/typex/class_test.go
index 1572448..8fc250f 100644
--- a/sdks/go/pkg/beam/core/typex/class_test.go
+++ b/sdks/go/pkg/beam/core/typex/class_test.go
@@ -40,6 +40,10 @@ func TestClassOf(t *testing.T) {
 		{reflectx.Uint32, Concrete},
 		{reflectx.Uint64, Concrete},
 		{reflectx.String, Concrete},
+		{reflectx.Float32, Concrete},
+		{reflectx.Float64, Concrete},
+		{reflect.TypeOf(complex64(0)), Concrete},
+		{reflect.TypeOf(complex128(0)), Concrete},
 		{reflect.TypeOf(struct{ A int }{}), Concrete},
 		{reflect.TypeOf(struct {
 			A int
@@ -47,6 +51,31 @@ func TestClassOf(t *testing.T) {
 		}{}), Concrete},
 		{reflect.TypeOf(struct{ A []int }{}), Concrete},
 		{reflect.TypeOf(reflect.Value{}), Concrete}, // ok: private fields
+		{reflect.TypeOf(map[string]int{}), Concrete},
+		{reflect.TypeOf(map[string]func(){}), Invalid},
+		{reflect.TypeOf(map[error]int{}), Invalid},
+		{reflect.TypeOf([4]int{}), Concrete},
+		{reflect.TypeOf([1]string{}), Concrete},
+		{reflect.TypeOf([0]string{}), Concrete},
+		{reflect.TypeOf([3]struct{ Q []string }{}), Concrete},
+		{reflect.TypeOf([0]interface{}{}), Concrete},
+		{reflect.TypeOf([1]string{}), Concrete},
+		{reflect.TypeOf([0]string{}), Concrete},
+		{reflect.TypeOf([0]interface{}{}), Concrete},
+		{reflect.PtrTo(reflectx.String), Concrete},
+		{reflect.PtrTo(reflectx.Uint32), Concrete},
+		{reflect.PtrTo(reflectx.Bool), Concrete},
+		{reflect.TypeOf([4]int{}), Concrete},
+		{reflect.TypeOf(&struct{ A int }{}), Concrete},
+		{reflect.TypeOf(&struct{ A *int }{}), Concrete},
+		{reflect.TypeOf([4]int{}), Concrete},
+
+		// Recursive types.
+		{reflect.TypeOf(RecursivePtrTest{}), Concrete},
+		{reflect.TypeOf(RecursiveSliceTest{}), Concrete},
+		{reflect.TypeOf(RecursiveMapTest{}), Concrete},
+		{reflect.TypeOf(RecursivePtrArrayTest{}), Concrete},
+		{reflect.TypeOf(RecursiveBadTest{}), Invalid},
 
 		{reflect.TypeOf([]X{}), Container},
 		{reflect.TypeOf([][][]X{}), Container},
@@ -63,18 +92,20 @@ func TestClassOf(t *testing.T) {
 
 		{EventTimeType, Invalid},                                     // special
 		{WindowType, Invalid},                                        // special
-		{reflect.TypeOf((*ConcreteTestWindow)(nil)).Elem(), Invalid}, // also special
+		{reflectx.Context, Invalid},                                  // special
+		{reflectx.Error, Invalid},                                    // special
+		{reflect.TypeOf((*ConcreteTestWindow)(nil)).Elem(), Invalid}, // special
 
 		{KVType, Composite},
 		{CoGBKType, Composite},
 		{WindowedValueType, Composite},
 
-		{reflect.TypeOf((*interface{})(nil)).Elem(), Invalid}, // empty interface
-		{reflectx.Context, Invalid},                           // interface
-		{reflectx.Error, Invalid},                             // interface
-		{reflect.TypeOf(func() {}), Invalid},                  // function
-		{reflect.TypeOf(make(chan int)), Invalid},             // chan
-		{reflect.TypeOf(struct{ A error }{}), Invalid},        // public interface field
+		{reflect.TypeOf((*interface{})(nil)).Elem(), Concrete}, // special
+
+		{reflect.TypeOf(uintptr(0)), Invalid},          // uintptr
+		{reflect.TypeOf(func() {}), Invalid},           // function
+		{reflect.TypeOf(make(chan int)), Invalid},      // chan
+		{reflect.TypeOf(struct{ A error }{}), Invalid}, // public interface field
 	}
 
 	for _, test := range tests {
@@ -85,6 +116,29 @@ func TestClassOf(t *testing.T) {
 	}
 }
 
+type RecursivePtrTest struct {
+	Ptr *RecursivePtrTest
+}
+
+type RecursiveSliceTest struct {
+	Slice []RecursiveSliceTest
+}
+
+type RecursiveMapTest struct {
+	Map map[int]RecursiveMapTest
+}
+
+// Go rejects a struct that contains itself by value (invalid recursive type);
+// reaching itself through a pointer, slice, or map is allowed.
+type RecursivePtrArrayTest struct {
+	Array [12]*RecursivePtrArrayTest // Arrays are stored inline, so the elements must be pointers.
+}
+
+type RecursiveBadTest struct {
+	Map map[RecursivePtrTest]*RecursiveMapTest
+	Bad func() // Func fields are never concrete, so ClassOf reports Invalid.
+}
+
 type ConcreteTestWindow int
 
 func (ConcreteTestWindow) MaxTimestamp() EventTime {
diff --git a/sdks/go/pkg/beam/core/typex/fulltype.go b/sdks/go/pkg/beam/core/typex/fulltype.go
index a6ab4a3..e40aa7a 100644
--- a/sdks/go/pkg/beam/core/typex/fulltype.go
+++ b/sdks/go/pkg/beam/core/typex/fulltype.go
@@ -13,7 +13,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-// Package typex contains full type representation and utilities for type checking.
+// Package typex contains full type representation for PCollections and DoFns, and
+// utilities for type checking.
 package typex
 
 import (
diff --git a/sdks/go/pkg/beam/core/typex/special.go b/sdks/go/pkg/beam/core/typex/special.go
index 0a9991e..003e6d0 100644
--- a/sdks/go/pkg/beam/core/typex/special.go
+++ b/sdks/go/pkg/beam/core/typex/special.go
@@ -16,8 +16,9 @@
 package typex
 
 import (
-	"github.com/apache/beam/sdks/go/pkg/beam/core/graph/mtime"
 	"reflect"
+
+	"github.com/apache/beam/sdks/go/pkg/beam/core/graph/mtime"
 )
 
 // This file defines data types that programs use to indicate a
diff --git a/sdks/go/pkg/beam/forward.go b/sdks/go/pkg/beam/forward.go
index a76568f..af722d5 100644
--- a/sdks/go/pkg/beam/forward.go
+++ b/sdks/go/pkg/beam/forward.go
@@ -18,6 +18,7 @@ package beam
 import (
 	"reflect"
 
+	"github.com/apache/beam/sdks/go/pkg/beam/core/graph/coder"
 	"github.com/apache/beam/sdks/go/pkg/beam/core/runtime"
 	"github.com/apache/beam/sdks/go/pkg/beam/core/typex"
 )
@@ -58,6 +59,47 @@ func RegisterInit(hook func()) {
 	runtime.RegisterInit(hook)
 }
 
+// RegisterCoder registers a user defined coder for a given type, to be
+// used if there is no existing beam coder for that type.
+// Must be called prior to beam.Init(), preferably in an init() function.
+//
+// The coder used for a given type follows this ordering:
+//   1. Coders for known Beam types.
+//   2. Coders registered for specific types.
+//   3. Coders registered for interface types.
+//   4. Default coder (JSON).
+//
+// Coders for interface types are iterated over to check if a type
+// satisfies them, and the most recently registered one will be used.
+//
+// Repeated registrations of the same type override prior ones.
+//
+// RegisterCoder additionally registers the type and the coder functions
+// as per RegisterType and RegisterFunction, to avoid redundant calls.
+//
+// Supported Encoder Signatures
+//
+//  func(T) []byte
+//  func(reflect.Type, T) []byte
+//  func(T) ([]byte, error)
+//  func(reflect.Type, T) ([]byte, error)
+//
+// Supported Decoder Signatures
+//
+//  func([]byte) T
+//  func(reflect.Type, []byte) T
+//  func([]byte) (T, error)
+//  func(reflect.Type, []byte) (T, error)
+//
+// where T is the matching user type. RegisterCoder panics if the encoder
+// or decoder fails validation for t.
+func RegisterCoder(t reflect.Type, encoder, decoder interface{}) {
+	runtime.RegisterType(t)
+	runtime.RegisterFunction(encoder)
+	runtime.RegisterFunction(decoder)
+	coder.RegisterCoder(t, encoder, decoder) // Validates encoder/decoder for t; panics on failure.
+}
+
 // Init is the hook that all user code must call after flags processing and
 // other static initialization, for now.
 func Init() {
diff --git a/sdks/go/pkg/beam/util.go b/sdks/go/pkg/beam/util.go
index 52ef117..aa0ff7c 100644
--- a/sdks/go/pkg/beam/util.go
+++ b/sdks/go/pkg/beam/util.go
@@ -16,7 +16,7 @@
 package beam
 
 //go:generate go install github.com/apache/beam/sdks/go/cmd/starcgen
-//go:generate starcgen --package=beam --identifiers=addFixedKeyFn,dropKeyFn,dropValueFn,swapKVFn,explodeFn,JSONDec,JSONEnc,ProtoEnc,ProtoDec,makePartitionFn,createFn
+//go:generate starcgen --package=beam --identifiers=addFixedKeyFn,dropKeyFn,dropValueFn,swapKVFn,explodeFn,jsonDec,jsonEnc,protoEnc,protoDec,makePartitionFn,createFn
 
 // We have some freedom to create various utilities, users can use depending on
 // preferences. One point of keeping Pipeline transformation functions plain Go


Mime
View raw message