diff --git a/listener/tproxy/udp.go b/listener/tproxy/udp.go index 52becbc9..f738ef0d 100644 --- a/listener/tproxy/udp.go +++ b/listener/tproxy/udp.go @@ -74,12 +74,14 @@ func NewUDP(addr string, tunnel C.Tunnel, additions ...inbound.Addition) (*UDPLi continue } - rAddr, dscp, err := getOrigDstAndDSCP(oob[:oobn]) - additions = append(additions, inbound.WithDSCP(dscp)) + rAddr, err := getOrigDst(oob[:oobn]) if err != nil { continue } + dscp, _ := getDSCP(oob[:oobn]) + additions = append(additions, inbound.WithDSCP(dscp)) + if rAddr.Addr().Is4() { // try to unmap 4in6 address lAddr = netip.AddrPortFrom(lAddr.Addr().Unmap(), lAddr.Port()) diff --git a/listener/tproxy/udp_linux.go b/listener/tproxy/udp_linux.go index eee09c00..f7f0fd0f 100644 --- a/listener/tproxy/udp_linux.go +++ b/listener/tproxy/udp_linux.go @@ -96,23 +96,22 @@ func udpAddrFamily(net string, lAddr, rAddr netip.AddrPort) int { return syscall.AF_INET6 } -func getOrigDstAndDSCP(oob []byte) (netip.AddrPort, uint8, error) { +func getOrigDst(oob []byte) (netip.AddrPort, error) { // oob contains socket control messages which we need to parse. scms, err := unix.ParseSocketControlMessage(oob) if err != nil { - return netip.AddrPort{}, 0, fmt.Errorf("parse control message: %w", err) + return netip.AddrPort{}, fmt.Errorf("parse control message: %w", err) } // retrieve the destination address from the SCM. sa, err := unix.ParseOrigDstAddr(&scms[1]) + if err != nil { - return netip.AddrPort{}, 0, fmt.Errorf("retrieve destination: %w", err) + return netip.AddrPort{}, fmt.Errorf("retrieve destination: %w", err) } - // retrieve DSCP from the SCM - dscp, err := parseDSCP(&scms[0]) if err != nil { - return netip.AddrPort{}, 0, fmt.Errorf("retrieve DSCP: %w", err) + return netip.AddrPort{}, fmt.Errorf("retrieve DSCP: %w", err) } // encode the destination address into a cmsg. @@ -123,10 +122,22 @@ func getOrigDstAndDSCP(oob []byte) (netip.AddrPort, uint8, error) { case *unix.SockaddrInet6: rAddr = netip.AddrPortFrom(netip.AddrFrom16(v.Addr), uint16(v.Port)) default: - return netip.AddrPort{}, 0, fmt.Errorf("unsupported address type: %T", v) + return netip.AddrPort{}, fmt.Errorf("unsupported address type: %T", v) } - return rAddr, dscp, nil + return rAddr, nil +} + +func getDSCP (oob []byte) (uint8, error) { + scms, err := unix.ParseSocketControlMessage(oob) + if err != nil { + return 0, fmt.Errorf("parse control message: %w", err) + } + dscp, err := parseDSCP(&scms[0]) + if err != nil { + return 0, fmt.Errorf("retrieve DSCP: %w", err) + } + return dscp, nil } func parseDSCP(m *unix.SocketControlMessage) (uint8, error) { diff --git a/listener/tproxy/udp_other.go b/listener/tproxy/udp_other.go index b35b07dd..2e0e0ae7 100644 --- a/listener/tproxy/udp_other.go +++ b/listener/tproxy/udp_other.go @@ -12,6 +12,10 @@ func getOrigDst(oob []byte) (netip.AddrPort, error) { return netip.AddrPort{}, errors.New("UDP redir not supported on current platform") } +func getDSCP(oob []byte) (uint8, error) { + return 0, errors.New("UDP redir not supported on current platform") +} + func dialUDP(network string, lAddr, rAddr netip.AddrPort) (*net.UDPConn, error) { return nil, errors.New("UDP redir not supported on current platform") }