From 487d7fa81fd2748ca7fe7a5be161b3159702ab3b Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Wed, 9 Apr 2025 17:53:36 +0800 Subject: [PATCH] fix: panic under some stupid input config --- adapter/outbound/reality.go | 11 +++++++---- listener/reality/reality.go | 6 +++++- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/adapter/outbound/reality.go b/adapter/outbound/reality.go index 6711f378..e5b090a8 100644 --- a/adapter/outbound/reality.go +++ b/adapter/outbound/reality.go @@ -20,16 +20,19 @@ func (o RealityOptions) Parse() (*tlsC.RealityConfig, error) { config := new(tlsC.RealityConfig) const x25519ScalarSize = 32 - var publicKey [x25519ScalarSize]byte - n, err := base64.RawURLEncoding.Decode(publicKey[:], []byte(o.PublicKey)) - if err != nil || n != x25519ScalarSize { + publicKey, err := base64.RawURLEncoding.DecodeString(o.PublicKey) + if err != nil || len(publicKey) != x25519ScalarSize { return nil, errors.New("invalid REALITY public key") } - config.PublicKey, err = ecdh.X25519().NewPublicKey(publicKey[:]) + config.PublicKey, err = ecdh.X25519().NewPublicKey(publicKey) if err != nil { return nil, fmt.Errorf("fail to create REALITY public key: %w", err) } + n := hex.DecodedLen(len(o.ShortID)) + if n > tlsC.RealityMaxShortIDLen { + return nil, errors.New("invalid REALITY short id") + } n, err = hex.Decode(config.ShortID[:], []byte(o.ShortID)) if err != nil || n > tlsC.RealityMaxShortIDLen { return nil, errors.New("invalid REALITY short ID") diff --git a/listener/reality/reality.go b/listener/reality/reality.go index f629f2b4..8157844d 100644 --- a/listener/reality/reality.go +++ b/listener/reality/reality.go @@ -50,7 +50,11 @@ func (c Config) Build() (*Builder, error) { realityConfig.ShortIds = make(map[[8]byte]bool) for i, shortIDString := range c.ShortID { var shortID [8]byte - decodedLen, err := hex.Decode(shortID[:], []byte(shortIDString)) + decodedLen := hex.DecodedLen(len(shortIDString)) + if decodedLen > 8 { + return nil, fmt.Errorf("invalid short_id[%d]: %s", i, shortIDString) + } + decodedLen, err = hex.Decode(shortID[:], []byte(shortIDString)) if err != nil { return nil, fmt.Errorf("decode short_id[%d] '%s': %w", i, shortIDString, err) }