From 0117c80ec89859c1ac0a32b358c3b99d22a2ceef Mon Sep 17 00:00:00 2001
From: esrrhs <esrrhs@163.com>
Date: Tue, 8 Jan 2019 10:31:52 +0800
Subject: [PATCH] fix

---
 client.go     | 17 +++++++++++++----
 pingtunnel.go | 18 +++++++++++++-----
 server.go     | 17 +++++++++++++----
 3 files changed, 39 insertions(+), 13 deletions(-)

diff --git a/client.go b/client.go
index 5538977..ac3deb9 100644
--- a/client.go
+++ b/client.go
@@ -9,7 +9,7 @@ import (
 	"time"
 )
 
-func NewClient(addr string, server string, target string, timeout int, sproto int, rproto int, catch int) (*Client, error) {
+func NewClient(addr string, server string, target string, timeout int, sproto int, rproto int, catch int, key int) (*Client, error) {
 
 	ipaddr, err := net.ResolveUDPAddr("udp", addr)
 	if err != nil {
@@ -33,6 +33,7 @@ func NewClient(addr string, server string, target string, timeout int, sproto in
 		sproto:       sproto,
 		rproto:       rproto,
 		catch:        catch,
+		key:          key,
 	}, nil
 }
 
@@ -44,6 +45,7 @@ type Client struct {
 	sproto  int
 	rproto  int
 	catch   int
+	key     int
 
 	ipaddr *net.UDPAddr
 	addr   string
@@ -180,7 +182,8 @@ func (p *Client) Accept() error {
 		}
 
 		clientConn.activeTime = now
-		sendICMP(p.id, p.sequence, *p.conn, p.ipaddrServer, p.targetAddr, clientConn.id, (uint32)(DATA), bytes[:n], p.sproto, p.rproto, p.catch)
+		sendICMP(p.id, p.sequence, *p.conn, p.ipaddrServer, p.targetAddr, clientConn.id, (uint32)(DATA), bytes[:n],
+			p.sproto, p.rproto, p.catch, p.key)
 
 		p.sequence++
 
@@ -195,6 +198,10 @@ func (p *Client) processPacket(packet *Packet) {
 		return
 	}
 
+	if packet.key != p.key {
+		return
+	}
+
 	if packet.msgType == PING {
 		t := time.Time{}
 		t.UnmarshalBinary(packet.data)
@@ -259,7 +266,8 @@ func (p *Client) ping() {
 	if p.sendPacket == 0 {
 		now := time.Now()
 		b, _ := now.MarshalBinary()
-		sendICMP(p.id, p.sequence, *p.conn, p.ipaddrServer, p.targetAddr, "", (uint32)(PING), b, p.sproto, p.rproto, p.catch)
+		sendICMP(p.id, p.sequence, *p.conn, p.ipaddrServer, p.targetAddr, "", (uint32)(PING), b,
+			p.sproto, p.rproto, p.catch, p.key)
 		fmt.Printf("ping %s %s %d %d %d %d\n", p.addrServer, now.String(), p.sproto, p.rproto, p.id, p.sequence)
 		p.sequence++
 	}
@@ -279,7 +287,8 @@ func (p *Client) showNet() {
 func (p *Client) sendCatch() {
 	if p.catch > 0 {
 		for _, conn := range p.localIdToConnMap {
-			sendICMP(p.id, p.sequence, *p.conn, p.ipaddrServer, p.targetAddr, conn.id, (uint32)(CATCH), make([]byte, 0), p.sproto, p.rproto, p.catch)
+			sendICMP(p.id, p.sequence, *p.conn, p.ipaddrServer, p.targetAddr, conn.id, (uint32)(CATCH), make([]byte, 0),
+				p.sproto, p.rproto, p.catch, p.key)
 			p.sequence++
 			p.sendCatchPacket++
 		}
diff --git a/pingtunnel.go b/pingtunnel.go
index e52c02d..ee14a7f 100644
--- a/pingtunnel.go
+++ b/pingtunnel.go
@@ -29,6 +29,7 @@ type MyMsg struct {
 	Data    []byte
 	RPROTO  uint16
 	CATCH   uint16
+	KEY     uint32
 	ENDTYPE uint32
 }
 
@@ -37,7 +38,7 @@ func (p *MyMsg) Len(proto int) int {
 	if p == nil {
 		return 0
 	}
-	return 4 + p.LenString(p.ID) + p.LenString(p.TARGET) + p.LenData(p.Data) + 2 + 2 + 4
+	return 4 + p.LenString(p.ID) + p.LenString(p.TARGET) + p.LenData(p.Data) + 2 + 2 + 4 + 4
 }
 
 func (p *MyMsg) LenString(s string) int {
@@ -68,7 +69,9 @@ func (p *MyMsg) Marshal(proto int) ([]byte, error) {
 
 	binary.BigEndian.PutUint16(b[4+p.LenString(p.ID)+p.LenString(p.TARGET)+p.LenData(p.Data)+2:], uint16(p.CATCH))
 
-	binary.BigEndian.PutUint32(b[4+p.LenString(p.ID)+p.LenString(p.TARGET)+p.LenData(p.Data)+4:], uint32(p.ENDTYPE))
+	binary.BigEndian.PutUint32(b[4+p.LenString(p.ID)+p.LenString(p.TARGET)+p.LenData(p.Data)+4:], uint32(p.KEY))
+
+	binary.BigEndian.PutUint32(b[4+p.LenString(p.ID)+p.LenString(p.TARGET)+p.LenData(p.Data)+8:], uint32(p.ENDTYPE))
 
 	return b, nil
 }
@@ -105,7 +108,9 @@ func (p *MyMsg) Unmarshal(b []byte) error {
 
 	p.CATCH = binary.BigEndian.Uint16(b[4+p.LenString(p.ID)+p.LenString(p.TARGET)+p.LenData(p.Data)+2:])
 
-	p.ENDTYPE = binary.BigEndian.Uint32(b[4+p.LenString(p.ID)+p.LenString(p.TARGET)+p.LenData(p.Data)+4:])
+	p.KEY = binary.BigEndian.Uint32(b[4+p.LenString(p.ID)+p.LenString(p.TARGET)+p.LenData(p.Data)+4:])
+
+	p.ENDTYPE = binary.BigEndian.Uint32(b[4+p.LenString(p.ID)+p.LenString(p.TARGET)+p.LenData(p.Data)+8:])
 
 	return nil
 }
@@ -131,7 +136,7 @@ func (p *MyMsg) UnmarshalData(b []byte) []byte {
 }
 
 func sendICMP(id int, sequence int, conn icmp.PacketConn, server *net.IPAddr, target string,
-	connId string, msgType uint32, data []byte, sproto int, rproto int, catch int) {
+	connId string, msgType uint32, data []byte, sproto int, rproto int, catch int, key int) {
 
 	m := &MyMsg{
 		ID:      connId,
@@ -140,6 +145,7 @@ func sendICMP(id int, sequence int, conn icmp.PacketConn, server *net.IPAddr, ta
 		Data:    data,
 		RPROTO:  (uint16)(rproto),
 		CATCH:   (uint16)(catch),
+		KEY:     (uint32)(key),
 		ENDTYPE: END,
 	}
 
@@ -217,7 +223,8 @@ func recvICMP(conn icmp.PacketConn, recv chan<- *Packet) {
 
 		recv <- &Packet{msgType: my.TYPE, data: my.Data, id: my.ID, target: my.TARGET,
 			src: srcaddr.(*net.IPAddr), rproto: (int)((int16)(my.RPROTO)),
-			echoId: echoId, echoSeq: echoSeq, catch: (int)((int16)(my.CATCH))}
+			echoId: echoId, echoSeq: echoSeq, catch: (int)((int16)(my.CATCH)),
+			key: (int)(my.KEY)}
 	}
 }
 
@@ -231,6 +238,7 @@ type Packet struct {
 	echoId  int
 	echoSeq int
 	catch   int
+	key     int
 }
 
 func UniqueId() string {
diff --git a/server.go b/server.go
index a39890d..45f397b 100644
--- a/server.go
+++ b/server.go
@@ -7,14 +7,16 @@ import (
 	"time"
 )
 
-func NewServer(timeout int) (*Server, error) {
+func NewServer(timeout int, key int) (*Server, error) {
 	return &Server{
 		timeout: timeout,
+		key:     key,
 	}, nil
 }
 
 type Server struct {
 	timeout int
+	key     int
 
 	conn *icmp.PacketConn
 
@@ -73,6 +75,10 @@ func (p *Server) Run() {
 
 func (p *Server) processPacket(packet *Packet) {
 
+	if packet.key != p.key {
+		return
+	}
+
 	p.echoId = packet.echoId
 	p.echoSeq = packet.echoSeq
 
@@ -80,7 +86,8 @@ func (p *Server) processPacket(packet *Packet) {
 		t := time.Time{}
 		t.UnmarshalBinary(packet.data)
 		fmt.Printf("ping from %s %s %d %d %d\n", packet.src.String(), t.String(), packet.rproto, packet.echoId, packet.echoSeq)
-		sendICMP(packet.echoId, packet.echoSeq, *p.conn, packet.src, "", "", (uint32)(PING), packet.data, packet.rproto, -1, 0)
+		sendICMP(packet.echoId, packet.echoSeq, *p.conn, packet.src, "", "", (uint32)(PING), packet.data,
+			packet.rproto, -1, 0, p.key)
 		return
 	}
 
@@ -121,7 +128,8 @@ func (p *Server) processPacket(packet *Packet) {
 	if packet.msgType == CATCH {
 		select {
 		case re := <-udpConn.catchQueue:
-			sendICMP(packet.echoId, packet.echoSeq, *p.conn, re.src, "", re.id, (uint32)(CATCH), re.data, re.conn.rproto, -1, 0)
+			sendICMP(packet.echoId, packet.echoSeq, *p.conn, re.src, "", re.id, (uint32)(CATCH), re.data,
+				re.conn.rproto, -1, 0, p.key)
 			p.sendCatchPacket++
 		case <-time.After(time.Duration(1) * time.Millisecond):
 		}
@@ -174,7 +182,8 @@ func (p *Server) Recv(conn *ServerConn, id string, src *net.IPAddr) {
 			case <-time.After(time.Duration(10) * time.Millisecond):
 			}
 		} else {
-			sendICMP(p.echoId, p.echoSeq, *p.conn, src, "", id, (uint32)(DATA), bytes[:n], conn.rproto, -1, 0)
+			sendICMP(p.echoId, p.echoSeq, *p.conn, src, "", id, (uint32)(DATA), bytes[:n],
+				conn.rproto, -1, 0, p.key)
 		}
 
 		p.sendPacket++