From 5b5ccc1b6ade6a2ef7233117d8bd80be52bc25b5 Mon Sep 17 00:00:00 2001 From: hossinasaadi Date: Wed, 26 Mar 2025 17:40:31 +0400 Subject: [PATCH] implement rule restriction --- app/dispatcher/default.go | 10 +++++ app/router/command/config.go | 6 +++ app/router/config.go | 7 ++- app/router/restriction.go | 83 ++++++++++++++++++++++++++++++++++++ app/router/router.go | 30 +++++++------ features/routing/router.go | 11 +++++ infra/conf/router.go | 32 +++++++++----- 7 files changed, 152 insertions(+), 27 deletions(-) create mode 100644 app/router/restriction.go diff --git a/app/dispatcher/default.go b/app/dispatcher/default.go index 7bc58056..612ff97e 100644 --- a/app/dispatcher/default.go +++ b/app/dispatcher/default.go @@ -435,6 +435,16 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport. } else if d.router != nil { if route, err := d.router.PickRoute(routingLink); err == nil { outTag := route.GetOutboundTag() + restriction := route.GetRestriction() + if restriction != nil { + sessionInbounds := session.InboundFromContext(ctx) + userIP := sessionInbounds.Source.Address.IP() + err = d.router.RestrictionRule(restriction, userIP) + if err != nil { + errors.LogError(ctx, err) + } + } + if h := d.ohm.GetHandler(outTag); h != nil { isPickRoute = 2 if route.GetRuleTag() == "" { diff --git a/app/router/command/config.go b/app/router/command/config.go index 8c2e1343..b5957178 100644 --- a/app/router/command/config.go +++ b/app/router/command/config.go @@ -4,6 +4,7 @@ import ( "strings" "github.com/xtls/xray-core/common/net" + route "github.com/xtls/xray-core/common/route" "github.com/xtls/xray-core/features/routing" ) @@ -32,6 +33,10 @@ func (c routingContext) GetRuleTag() string { return "" } +func (c routingContext) GetRestriction() *route.Restriction { + return c.Restriction +} + // GetSkipDNSResolve is a mock implementation here to match the interface, // SkipDNSResolve is set from dns module, no use if coming from a protobuf object? // TODO: please confirm @Vigilans @@ -62,6 +67,7 @@ var fieldMap = map[string]func(*RoutingContext, routing.Route){ "attributes": func(s *RoutingContext, r routing.Route) { s.Attributes = r.GetAttributes() }, "outbound_group": func(s *RoutingContext, r routing.Route) { s.OutboundGroupTags = r.GetOutboundGroupTags() }, "outbound": func(s *RoutingContext, r routing.Route) { s.OutboundTag = r.GetOutboundTag() }, + "restriction": func(s *RoutingContext, r routing.Route) { s.Restriction = r.GetRestriction() }, } // AsProtobufMessage takes selectors of fields and returns a function to convert routing.Route to protobuf RoutingContext. diff --git a/app/router/config.go b/app/router/config.go index f1740610..6a77d1d1 100644 --- a/app/router/config.go +++ b/app/router/config.go @@ -11,17 +11,16 @@ import ( ) type Rule struct { - Tag string - RuleTag string + *RoutingRule Balancer *Balancer Condition Condition } -func (r *Rule) GetTag() (string, error) { +func (r *Rule) GetTargetTag() (string, error) { if r.Balancer != nil { return r.Balancer.PickOutbound() } - return r.Tag, nil + return r.GetTag(), nil } // Apply checks rule matching of current routing context. diff --git a/app/router/restriction.go b/app/router/restriction.go new file mode 100644 index 00000000..88cf382c --- /dev/null +++ b/app/router/restriction.go @@ -0,0 +1,83 @@ +package router + +import ( + "time" + + "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/common/net" + route "github.com/xtls/xray-core/common/route" +) + +// RestrictionRule implements routing.Router. +func (r *Router) RestrictionRule(restriction *route.Restriction, ip net.IP) error { + + mask := uint32(128) + if ip.To4() != nil { + mask = 32 + } + + cidrList := []*CIDR{{Ip: ip, Prefix: mask}} + sourceIP := &GeoIP{Cidr: cidrList} + + // If rule exists, process it + if r.RuleExists(restriction.Tag) { + for _, rule := range r.rules { + if rule.RuleTag == restriction.Tag { + if shouldCleanup(restriction) { + errors.LogWarning(r.ctx, "restriction cleanup -> ", restriction.Tag, " after ", restriction.CleanInterval, " seconds") + restriction.LastCleanup = time.Now().Unix() + return r.RemoveRule(restriction.Tag) + } + + // If cleanup is not running, schedule it + scheduleCleanup(r, restriction) + + // Check if IP already exists in the list + if ipExistsInRestriction(rule.RoutingRule.SourceGeoip, cidrList[0]) { + return errors.New(ip.String(), " already exists in restriction list.") + } + + rule.RoutingRule.SourceGeoip = append(rule.RoutingRule.SourceGeoip, sourceIP) + r.RemoveRule(restriction.Tag) + return r.ReloadRules(&Config{Rule: []*RoutingRule{rule.RoutingRule}}, true) + } + } + } + + // If rule does not exist, create a new one + newRule := &RoutingRule{ + RuleTag: restriction.Tag, + TargetTag: &RoutingRule_Tag{Tag: restriction.OutboundTag}, + SourceGeoip: []*GeoIP{sourceIP}, + } + + errors.LogWarning(r.ctx, "restrict IP -> ", ip.String(), " for route violation.") + return r.ReloadRules(&Config{Rule: []*RoutingRule{newRule}}, true) +} + +func shouldCleanup(restriction *route.Restriction) bool { + return time.Now().Unix()-restriction.LastCleanup >= restriction.CleanInterval && restriction.LastCleanup != 0 +} + +func scheduleCleanup(r *Router, restriction *route.Restriction) { + if !r.isCleanupRunning && restriction.CleanInterval != 0 { + r.isCleanupRunning = true + restriction.LastCleanup = time.Now().Unix() + + go time.AfterFunc(time.Duration(restriction.CleanInterval)*time.Second, func() { + r.RestrictionRule(restriction, nil) + r.isCleanupRunning = false + }) + } +} + +func ipExistsInRestriction(sourceGeoip []*GeoIP, newCIDR *CIDR) bool { + for _, source := range sourceGeoip { + for _, cidr := range source.GetCidr() { + if string(cidr.GetIp()) == string(newCIDR.GetIp()) { + return true + } + } + } + return false +} diff --git a/app/router/router.go b/app/router/router.go index 2f35b3e7..7e79c15e 100644 --- a/app/router/router.go +++ b/app/router/router.go @@ -6,6 +6,7 @@ import ( "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" + route "github.com/xtls/xray-core/common/route" "github.com/xtls/xray-core/common/serial" "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/dns" @@ -21,10 +22,11 @@ type Router struct { balancers map[string]*Balancer dns dns.Client - ctx context.Context - ohm outbound.Manager - dispatcher routing.Dispatcher - mu sync.Mutex + ctx context.Context + ohm outbound.Manager + dispatcher routing.Dispatcher + mu sync.Mutex + isCleanupRunning bool } // Route is an implementation of routing.Route. @@ -33,6 +35,7 @@ type Route struct { outboundGroupTags []string outboundTag string ruleTag string + restriction *route.Restriction } // Init initializes the Router. @@ -60,9 +63,8 @@ func (r *Router) Init(ctx context.Context, config *Config, d dns.Client, ohm out return err } rr := &Rule{ - Condition: cond, - Tag: rule.GetTag(), - RuleTag: rule.GetRuleTag(), + RoutingRule: rule, + Condition: cond, } btag := rule.GetBalancingTag() if len(btag) > 0 { @@ -84,11 +86,11 @@ func (r *Router) PickRoute(ctx routing.Context) (routing.Route, error) { if err != nil { return nil, err } - tag, err := rule.GetTag() + tag, err := rule.GetTargetTag() if err != nil { return nil, err } - return &Route{Context: ctx, outboundTag: tag, ruleTag: rule.RuleTag}, nil + return &Route{Context: ctx, outboundTag: tag, ruleTag: rule.RuleTag, restriction: rule.Restriction}, nil } // AddRule implements routing.Router. @@ -134,9 +136,8 @@ func (r *Router) ReloadRules(config *Config, shouldAppend bool) error { return err } rr := &Rule{ - Condition: cond, - Tag: rule.GetTag(), - RuleTag: rule.GetRuleTag(), + RoutingRule: rule, + Condition: cond, } btag := rule.GetBalancingTag() if len(btag) > 0 { @@ -242,6 +243,11 @@ func (r *Route) GetRuleTag() string { return r.ruleTag } +// GetRestriction implements routing.Route. +func (r *Route) GetRestriction() *route.Restriction { + return r.restriction +} + func init() { common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { r := new(Router) diff --git a/features/routing/router.go b/features/routing/router.go index 174d59fd..53761434 100644 --- a/features/routing/router.go +++ b/features/routing/router.go @@ -2,6 +2,8 @@ package routing import ( "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/common/route" "github.com/xtls/xray-core/common/serial" "github.com/xtls/xray-core/features" ) @@ -16,6 +18,7 @@ type Router interface { PickRoute(ctx Context) (Route, error) AddRule(config *serial.TypedMessage, shouldAppend bool) error RemoveRule(tag string) error + RestrictionRule(restriction *route.Restriction, ip net.IP) error } // Route is the routing result of Router feature. @@ -33,6 +36,9 @@ type Route interface { // GetRuleTag returns the matching rule tag for debugging if exists GetRuleTag() string + + // GetRestriction. + GetRestriction() *route.Restriction } // RouterType return the type of Router interface. Can be used to implement common.HasType. @@ -65,6 +71,11 @@ func (DefaultRouter) RemoveRule(tag string) error { return common.ErrNoClue } +// RestrictionRule implements Router. +func (DefaultRouter) RestrictionRule(restriction *route.Restriction, ip net.IP) error { + return common.ErrNoClue +} + // Start implements common.Runnable. func (DefaultRouter) Start() error { return nil diff --git a/infra/conf/router.go b/infra/conf/router.go index 2065f96b..0b2fdc85 100644 --- a/infra/conf/router.go +++ b/infra/conf/router.go @@ -10,6 +10,7 @@ import ( "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/platform/filesystem" + "github.com/xtls/xray-core/common/route" "github.com/xtls/xray-core/common/serial" "google.golang.org/protobuf/proto" ) @@ -531,17 +532,18 @@ func ToCidrList(ips StringList) ([]*router.GeoIP, error) { func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) { type RawFieldRule struct { RouterRule - Domain *StringList `json:"domain"` - Domains *StringList `json:"domains"` - IP *StringList `json:"ip"` - Port *PortList `json:"port"` - Network *NetworkList `json:"network"` - SourceIP *StringList `json:"source"` - SourcePort *PortList `json:"sourcePort"` - User *StringList `json:"user"` - InboundTag *StringList `json:"inboundTag"` - Protocols *StringList `json:"protocol"` - Attributes map[string]string `json:"attrs"` + Domain *StringList `json:"domain"` + Domains *StringList `json:"domains"` + IP *StringList `json:"ip"` + Port *PortList `json:"port"` + Network *NetworkList `json:"network"` + SourceIP *StringList `json:"source"` + SourcePort *PortList `json:"sourcePort"` + User *StringList `json:"user"` + InboundTag *StringList `json:"inboundTag"` + Protocols *StringList `json:"protocol"` + Attributes map[string]string `json:"attrs"` + Restriction *route.Restriction `json:"restriction"` } rawFieldRule := new(RawFieldRule) err := json.Unmarshal(msg, rawFieldRule) @@ -638,6 +640,14 @@ func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) { rule.Attributes = rawFieldRule.Attributes } + if rawFieldRule.Restriction != nil { + // set parent outbound tag + if rawFieldRule.Restriction.OutboundTag == "" { + rawFieldRule.Restriction.OutboundTag = rawFieldRule.OutboundTag + } + rule.Restriction = rawFieldRule.Restriction + } + return rule, nil }