implement rule restriction

This commit is contained in:
hossinasaadi 2025-03-26 17:40:31 +04:00
parent 5fe51c2384
commit 5b5ccc1b6a
7 changed files with 152 additions and 27 deletions

View file

@ -435,6 +435,16 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.
} else if d.router != nil {
if route, err := d.router.PickRoute(routingLink); err == nil {
outTag := route.GetOutboundTag()
restriction := route.GetRestriction()
if restriction != nil {
sessionInbounds := session.InboundFromContext(ctx)
userIP := sessionInbounds.Source.Address.IP()
err = d.router.RestrictionRule(restriction, userIP)
if err != nil {
errors.LogError(ctx, err)
}
}
if h := d.ohm.GetHandler(outTag); h != nil {
isPickRoute = 2
if route.GetRuleTag() == "" {

View file

@ -4,6 +4,7 @@ import (
"strings"
"github.com/xtls/xray-core/common/net"
route "github.com/xtls/xray-core/common/route"
"github.com/xtls/xray-core/features/routing"
)
@ -32,6 +33,10 @@ func (c routingContext) GetRuleTag() string {
return ""
}
func (c routingContext) GetRestriction() *route.Restriction {
return c.Restriction
}
// GetSkipDNSResolve is a mock implementation here to match the interface,
// SkipDNSResolve is set from dns module, no use if coming from a protobuf object?
// TODO: please confirm @Vigilans
@ -62,6 +67,7 @@ var fieldMap = map[string]func(*RoutingContext, routing.Route){
"attributes": func(s *RoutingContext, r routing.Route) { s.Attributes = r.GetAttributes() },
"outbound_group": func(s *RoutingContext, r routing.Route) { s.OutboundGroupTags = r.GetOutboundGroupTags() },
"outbound": func(s *RoutingContext, r routing.Route) { s.OutboundTag = r.GetOutboundTag() },
"restriction": func(s *RoutingContext, r routing.Route) { s.Restriction = r.GetRestriction() },
}
// AsProtobufMessage takes selectors of fields and returns a function to convert routing.Route to protobuf RoutingContext.

View file

@ -11,17 +11,16 @@ import (
)
type Rule struct {
Tag string
RuleTag string
*RoutingRule
Balancer *Balancer
Condition Condition
}
func (r *Rule) GetTag() (string, error) {
func (r *Rule) GetTargetTag() (string, error) {
if r.Balancer != nil {
return r.Balancer.PickOutbound()
}
return r.Tag, nil
return r.GetTag(), nil
}
// Apply checks rule matching of current routing context.

83
app/router/restriction.go Normal file
View file

@ -0,0 +1,83 @@
package router
import (
"time"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
route "github.com/xtls/xray-core/common/route"
)
// RestrictionRule implements routing.Router.
func (r *Router) RestrictionRule(restriction *route.Restriction, ip net.IP) error {
mask := uint32(128)
if ip.To4() != nil {
mask = 32
}
cidrList := []*CIDR{{Ip: ip, Prefix: mask}}
sourceIP := &GeoIP{Cidr: cidrList}
// If rule exists, process it
if r.RuleExists(restriction.Tag) {
for _, rule := range r.rules {
if rule.RuleTag == restriction.Tag {
if shouldCleanup(restriction) {
errors.LogWarning(r.ctx, "restriction cleanup -> ", restriction.Tag, " after ", restriction.CleanInterval, " seconds")
restriction.LastCleanup = time.Now().Unix()
return r.RemoveRule(restriction.Tag)
}
// If cleanup is not running, schedule it
scheduleCleanup(r, restriction)
// Check if IP already exists in the list
if ipExistsInRestriction(rule.RoutingRule.SourceGeoip, cidrList[0]) {
return errors.New(ip.String(), " already exists in restriction list.")
}
rule.RoutingRule.SourceGeoip = append(rule.RoutingRule.SourceGeoip, sourceIP)
r.RemoveRule(restriction.Tag)
return r.ReloadRules(&Config{Rule: []*RoutingRule{rule.RoutingRule}}, true)
}
}
}
// If rule does not exist, create a new one
newRule := &RoutingRule{
RuleTag: restriction.Tag,
TargetTag: &RoutingRule_Tag{Tag: restriction.OutboundTag},
SourceGeoip: []*GeoIP{sourceIP},
}
errors.LogWarning(r.ctx, "restrict IP -> ", ip.String(), " for route violation.")
return r.ReloadRules(&Config{Rule: []*RoutingRule{newRule}}, true)
}
func shouldCleanup(restriction *route.Restriction) bool {
return time.Now().Unix()-restriction.LastCleanup >= restriction.CleanInterval && restriction.LastCleanup != 0
}
func scheduleCleanup(r *Router, restriction *route.Restriction) {
if !r.isCleanupRunning && restriction.CleanInterval != 0 {
r.isCleanupRunning = true
restriction.LastCleanup = time.Now().Unix()
go time.AfterFunc(time.Duration(restriction.CleanInterval)*time.Second, func() {
r.RestrictionRule(restriction, nil)
r.isCleanupRunning = false
})
}
}
func ipExistsInRestriction(sourceGeoip []*GeoIP, newCIDR *CIDR) bool {
for _, source := range sourceGeoip {
for _, cidr := range source.GetCidr() {
if string(cidr.GetIp()) == string(newCIDR.GetIp()) {
return true
}
}
}
return false
}

View file

@ -6,6 +6,7 @@ import (
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/errors"
route "github.com/xtls/xray-core/common/route"
"github.com/xtls/xray-core/common/serial"
"github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/dns"
@ -21,10 +22,11 @@ type Router struct {
balancers map[string]*Balancer
dns dns.Client
ctx context.Context
ohm outbound.Manager
dispatcher routing.Dispatcher
mu sync.Mutex
ctx context.Context
ohm outbound.Manager
dispatcher routing.Dispatcher
mu sync.Mutex
isCleanupRunning bool
}
// Route is an implementation of routing.Route.
@ -33,6 +35,7 @@ type Route struct {
outboundGroupTags []string
outboundTag string
ruleTag string
restriction *route.Restriction
}
// Init initializes the Router.
@ -60,9 +63,8 @@ func (r *Router) Init(ctx context.Context, config *Config, d dns.Client, ohm out
return err
}
rr := &Rule{
Condition: cond,
Tag: rule.GetTag(),
RuleTag: rule.GetRuleTag(),
RoutingRule: rule,
Condition: cond,
}
btag := rule.GetBalancingTag()
if len(btag) > 0 {
@ -84,11 +86,11 @@ func (r *Router) PickRoute(ctx routing.Context) (routing.Route, error) {
if err != nil {
return nil, err
}
tag, err := rule.GetTag()
tag, err := rule.GetTargetTag()
if err != nil {
return nil, err
}
return &Route{Context: ctx, outboundTag: tag, ruleTag: rule.RuleTag}, nil
return &Route{Context: ctx, outboundTag: tag, ruleTag: rule.RuleTag, restriction: rule.Restriction}, nil
}
// AddRule implements routing.Router.
@ -134,9 +136,8 @@ func (r *Router) ReloadRules(config *Config, shouldAppend bool) error {
return err
}
rr := &Rule{
Condition: cond,
Tag: rule.GetTag(),
RuleTag: rule.GetRuleTag(),
RoutingRule: rule,
Condition: cond,
}
btag := rule.GetBalancingTag()
if len(btag) > 0 {
@ -242,6 +243,11 @@ func (r *Route) GetRuleTag() string {
return r.ruleTag
}
// GetRestriction implements routing.Route.
func (r *Route) GetRestriction() *route.Restriction {
return r.restriction
}
func init() {
common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
r := new(Router)

View file

@ -2,6 +2,8 @@ package routing
import (
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/route"
"github.com/xtls/xray-core/common/serial"
"github.com/xtls/xray-core/features"
)
@ -16,6 +18,7 @@ type Router interface {
PickRoute(ctx Context) (Route, error)
AddRule(config *serial.TypedMessage, shouldAppend bool) error
RemoveRule(tag string) error
RestrictionRule(restriction *route.Restriction, ip net.IP) error
}
// Route is the routing result of Router feature.
@ -33,6 +36,9 @@ type Route interface {
// GetRuleTag returns the matching rule tag for debugging if exists
GetRuleTag() string
// GetRestriction.
GetRestriction() *route.Restriction
}
// RouterType return the type of Router interface. Can be used to implement common.HasType.
@ -65,6 +71,11 @@ func (DefaultRouter) RemoveRule(tag string) error {
return common.ErrNoClue
}
// RestrictionRule implements Router.
func (DefaultRouter) RestrictionRule(restriction *route.Restriction, ip net.IP) error {
return common.ErrNoClue
}
// Start implements common.Runnable.
func (DefaultRouter) Start() error {
return nil

View file

@ -10,6 +10,7 @@ import (
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/platform/filesystem"
"github.com/xtls/xray-core/common/route"
"github.com/xtls/xray-core/common/serial"
"google.golang.org/protobuf/proto"
)
@ -531,17 +532,18 @@ func ToCidrList(ips StringList) ([]*router.GeoIP, error) {
func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) {
type RawFieldRule struct {
RouterRule
Domain *StringList `json:"domain"`
Domains *StringList `json:"domains"`
IP *StringList `json:"ip"`
Port *PortList `json:"port"`
Network *NetworkList `json:"network"`
SourceIP *StringList `json:"source"`
SourcePort *PortList `json:"sourcePort"`
User *StringList `json:"user"`
InboundTag *StringList `json:"inboundTag"`
Protocols *StringList `json:"protocol"`
Attributes map[string]string `json:"attrs"`
Domain *StringList `json:"domain"`
Domains *StringList `json:"domains"`
IP *StringList `json:"ip"`
Port *PortList `json:"port"`
Network *NetworkList `json:"network"`
SourceIP *StringList `json:"source"`
SourcePort *PortList `json:"sourcePort"`
User *StringList `json:"user"`
InboundTag *StringList `json:"inboundTag"`
Protocols *StringList `json:"protocol"`
Attributes map[string]string `json:"attrs"`
Restriction *route.Restriction `json:"restriction"`
}
rawFieldRule := new(RawFieldRule)
err := json.Unmarshal(msg, rawFieldRule)
@ -638,6 +640,14 @@ func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) {
rule.Attributes = rawFieldRule.Attributes
}
if rawFieldRule.Restriction != nil {
// set parent outbound tag
if rawFieldRule.Restriction.OutboundTag == "" {
rawFieldRule.Restriction.OutboundTag = rawFieldRule.OutboundTag
}
rule.Restriction = rawFieldRule.Restriction
}
return rule, nil
}