/* * This Source Code Form is subject to the terms of the Mozilla Public * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at https://mozilla.org/MPL/2.0/. */ package otrie import ( "fmt" "sort" "strings" "idio.link/go/netaddr/v2" ) var onesToMask = [9]byte{0, 128, 192, 224, 240, 248, 252, 254, 255} func newCursor(start *oTrieNode) *cursor { initCap := 8 // Not too big, not too small? cur := &cursor{ chain: make([]*oTrieNode, 1, initCap), coord: make([]int, 1, initCap), } cur.chain[0] = start return cur } func maskToOnes(m byte) int { for i, v := range onesToMask { if v == m { return i } } panic("BUG: otrie: maskToOnes()") } type cursor struct { chain []*oTrieNode coord []int } func (c *cursor) eof() bool { return c.depth() == 0 && c.coord[0] == len(c.chain[0].next) } func (c *cursor) String() string { l := 0 p := make([]string, 0, 4) for _, n := range c.chain[1:] { p = append(p, fmt.Sprintf("%d", n.octet)) l += maskToOnes(n.mask) } return strings.Join(p, ".") + fmt.Sprintf("/%d", l) } func (c *cursor) netAddrs() []*netaddr.NetAddr { nas := make([]*netaddr.NetAddr, 0, 2) l := 0 p := make([]byte, 0, 4) for _, n := range c.chain[1:] { p = append(p, n.octet) l += maskToOnes(n.mask) } for _, cut := range c.chain[c.depth()].cuts { tmp := make([]byte, len(p)) copy(tmp, p) for i := 0; i < (cut - c.depth()); i++ { tmp = append(tmp, byte(0)) } na := netaddr.NewNetAddr(tmp, l) nas = append(nas, na) } return nas } func (c *cursor) depth() int { return len(c.chain) - 1 } func (c *cursor) descend() int { depth := c.depth() nextNodes := c.chain[depth].next if len(nextNodes) != 0 { nextNode := nextNodes[c.coord[depth]] c.coord = append(c.coord, 0) c.chain = append(c.chain, nextNode) } return c.depth() } func (c *cursor) ascend() int { depth := c.depth() c.coord[depth] = 0 if depth == 0 { return depth } c.chain[depth] = nil c.coord = c.coord[0:depth] c.chain = c.chain[0:depth] c.coord[depth-1]++ return c.depth() } func (c *cursor) curNode() *oTrieNode { return c.chain[c.depth()] } func (c *cursor) step() int { if c.coord[c.depth()] == len(c.curNode().next) { return c.ascend() } return c.descend() } func (c *cursor) stepMember() (depth int) { depth = c.depth() for !c.eof() && len(c.curNode().cuts) == 0 { depth = c.step() } return } func contains(o1, m1, o2, m2 byte) bool { return m1 <= m2 && o1&m1 == o2&m1 } func overlaps(o1, m1, o2, m2 byte) bool { return contains(o1, m1, o2, m2) || contains(o2, m2, o1, m1) } func (c *cursor) reset() { for i := 1; i < len(c.chain); i++ { c.chain = nil c.coord[i] = 0 } c.coord[0] = 0 } /* Breadth-first search of OTrie */ func (c *cursor) search( p *netaddr.NetAddr, equiv ...func(byte, byte, byte, byte) bool, ) (match bool) { ones := p.Length() for !c.eof() { if c.curNode().hasCut(p.AddressSize()) { match = true return } if len(c.curNode().next) == 0 { return } ones -= min(ones, maskToOnes(c.curNode().mask)) octet := p.Address()[c.depth()] mask := onesToMask[byte(min(ones, 8))] equivTest := overlaps if len(equiv) > 0 { equivTest = equiv[0] } matchTest := func(i int) bool { next := c.curNode().next[i] return !prefixLessThan( next.octet, next.mask, octet, mask, ) || equivTest( next.octet, next.mask, octet, mask, ) } n := len(c.curNode().next) c.coord[c.depth()] = sort.Search(n, matchTest) if c.coord[c.depth()] == n { return } next := c.curNode().next[c.coord[c.depth()]] if !equivTest(next.octet, next.mask, octet, mask) { break } _ = c.descend() } return } /* Depth-first iteration over OTrie */ func (c *cursor) iter() <-chan []*oTrieNode { ch := make(chan []*oTrieNode) go func() { defer close(ch) for !c.eof() { if c.coord[c.depth()] == len(c.curNode().next) { if len(c.curNode().cuts) != 0 { tmp := make([]*oTrieNode, c.depth()) copy(tmp, c.chain[1:]) ch <- tmp } } _ = c.step() } c.reset() }() return ch } type OTrie interface { AddPrefix(*netaddr.NetAddr) error Contains(*netaddr.NetAddr) bool Overlaps(*netaddr.NetAddr) bool OverlapsRange(*netaddr.NetAddr, *netaddr.NetAddr) bool Print() string } func NewOTrie() OTrie { return &oTrieNode{} } func newOTrieNode(octet, mask byte) *oTrieNode { return &oTrieNode{ octet: octet, mask: mask, } } func min(a, b int) int { if a < b { return a } return b } /* Test against sheared isometry on byte prefixes */ func prefixLessThan(o1, m1, o2, m2 byte) bool { /* We seek an isometry from a binary trie embedded in R² to the integral number line, with ones bits branching to the right and zeros bits branching to the left. When considering only the ones bits of an n-bit binary number, the usual decimal encoding suffices, as each ith bit, for 0 <= i <= n-1, shifts a point at the origin to the right by 2^i. However, the zeros bits have no effect, in this sense, as leading zeros do not contribute to the integer value, resulting in many binary strings occupying the same point. This mapping is well-formed and surjective (every integer has a binary encoding), but it is not injective. -> R 0 00 000 0000 -> 0001 -> 001 0010 -> 0011 -> 01 010 0100 -> 0101 -> 011 0110 -> 0111 -> 1 10 100 1000 -> 1001 -> 101 1010 -> 1011 -> 11 110 1100 -> 1101 -> 111 1110 -> 1111 Instead, by allowing each zero bit to provide a negative contribution, we ensure that every path (binary string) maps to a unique point on the integral line. To see that this is an isometry, project the nodes of a binary trie embedded in R², with branch width w = 2^(n-j-1) at rank j and arbitrary branch height h, onto the integral line. Here, we express the above isometry by simply subtracting the contributions of the place values of the zeros bits, which we obtain by ones complement (XOR). Masking ensures we map only the prefix. Finally, we note that projecting the usual geometry of a binary trie onto the integral line places each ancestor at the midpoint of its descendants. This creates a disagreement between the partial and total orders, which makes it difficult to recover the original partial order generated by the trie. -> 0000 -> 000 -> 0001 -> 00 -> 0010 -> 001 -> 0011 -> 0 -> 0100 -> 010 -> 0101 -> 01 -> 0110 -> 011 -> 0111 0/0 -> 1000 -> 100 -> 1001 -> 10 -> 1010 -> 101 -> 1011 -> 1 -> 1100 -> 110 -> 1101 -> 11 -> 1110 -> 111 -> 1111 Instead, we require that the total order < given by the proposed isometry have the property that it always agree with the partial order ≪ given by the binary trie. That is, given nodes a, b of a binary trie, if a ≪ b, then a < b. Note that the converse does not hold, nor would it be useful if it did. To accomplish this, we effectively apply a horizontal shear to the right on the binary trie such that the root of each subtree falls to the right of its descendants. -> 0000 -> 0001 -> 000 -> 0010 -> 0011 -> 001 -> 00 -> 0100 -> 0101 -> 010 -> 0110 -> 0111 -> 011 -> 01 -> 0 -> 1000 -> 1001 -> 100 -> 1010 -> 1011 -> 101 -> 10 -> 1100 -> 1101 -> 110 -> 1110 -> 1111 -> 111 -> 11 -> 1 0/0 In the unsheared isometry, the nodes we wish to reorder are specifically those which precede their descendants-- that is, whenever a < b and b ≪ a, we would like to move a to the right of b so that b < a, thereby agreeing with the partial order. Since b ≪ a, it is the case that a represents a shorter prefix than b. Furthermore, since b is a descendant of a, b's prefix is guaranteed to share all of the bits of a. Conversely, if a and b share a prefix, then that prefix represents the most recent shared ancestor (least upper bound). And if the prefix shared by a and b is a's prefix, then it must be the case that a's prefix is shorter, and b is a descendant of a. Thus, b ≪ a. Therefore, we conclude that a must be reordered precisely when a < b and b shares a's prefix. Hence our assertion `f1 < f2 && o1 != o2&m1`. To reorder, we simply place prefixes with longer masks before prefixes with shorter ones, as given by `m1 > m2`. */ f1 := int(o1&m1) - int((o1^byte(255))&m1) f2 := int(o2&m2) - int((o2^byte(255))&m2) return f1 < f2 && (o1 != o2&m1 || m1 > m2) } /* Every octet trie node represents the end of a prefix, but not all represented prefixes are intended to be members of the set. For example, when a /24 is added, antecedent nodes representing a /8 and /16 are also created: /8 -> /16 -> /24 The `cuts` field is used to determine which nodes represent members of the set and which do not. Nodes with an empty `cuts` do not represent members. A node with one cut represents one member whose total size in octets is given by the cut. A node with cuts {4, 128} simultaneously represents a 4-byte member and a 128-byte member, both of which share the same prefix octets and prefix length. */ type oTrieNode struct { octet byte mask byte next []*oTrieNode cuts []int } func (n *oTrieNode) hasCut(cut int) bool { for _, c := range n.cuts { if c == cut { return true } } return false } func (n *oTrieNode) setCut(cut int) { if n.cuts == nil { n.cuts = make([]int, 0, 2) } if n.hasCut(cut) { return } n.cuts = append(n.cuts, cut) } func (root *oTrieNode) Print() (output string) { maskToOnes := func(m byte) int { for i, v := range onesToMask { if v == m { return i } } panic("BUG: *oTrieNode.Print(): maskToOnes()") } cur := newCursor(root) for chain := range cur.iter() { l := 0 p := make([]string, 0, len(chain)) for _, o := range chain { p = append(p, fmt.Sprintf("%d", o.octet)) l += maskToOnes(o.mask) } output += strings.Join(p, ".") for _, cut := range chain[len(chain)-1].cuts { for i := 0; i < (cut - len(chain)); i++ { output += ".0" } } output += fmt.Sprintf("/%d\n", l) } return } func (root *oTrieNode) String() string { return root.Print() } func (root *oTrieNode) AddPrefix(p *netaddr.NetAddr) error { nodeLessThan := func(cur *oTrieNode) func(int, int) bool { return func(i, j int) bool { return prefixLessThan( cur.next[i].octet, cur.next[i].mask, cur.next[j].octet, cur.next[j].mask, ) } } curNode := root ones := p.Length() for _, o := range p.Address() { if ones == 0 { break } mask := onesToMask[min(8, ones)] ones -= min(8, ones) octet := o & mask if curNode.next == nil { curNode.next = make([]*oTrieNode, 0, 64) } var nextNode *oTrieNode oIdx := sort.Search(len(curNode.next), func(i int) bool { return !prefixLessThan( curNode.next[i].octet, curNode.next[i].mask, octet, mask, ) }) switch { case oIdx != len(curNode.next) && curNode.next[oIdx].octet == octet && curNode.next[oIdx].mask == mask: nextNode = curNode.next[oIdx] if ones == 0 { nextNode.setCut(p.AddressSize()) } default: nextNode = newOTrieNode(octet, mask) if ones == 0 { nextNode.setCut(p.AddressSize()) } curNode.next = append(curNode.next, nextNode) if oIdx != len(curNode.next) { sort.Slice(curNode.next, nodeLessThan(curNode)) } } curNode = nextNode } return nil } /* Test for overlap; after all, even if a given prefix does * not contain a certain range, if the prefix is a subset of * that range, then the range is relevant. */ func (root *oTrieNode) Overlaps(p *netaddr.NetAddr) bool { return newCursor(root).search(p) } func (root *oTrieNode) Contains(p *netaddr.NetAddr) bool { return newCursor(root).search(p, contains) } /* Test address range ([a, b]) for overlap, ignoring address * lengths. */ func (root *oTrieNode) OverlapsRange( a *netaddr.NetAddr, b *netaddr.NetAddr, ) (match bool) { if a.AddressSize() != b.AddressSize() { return } a.SetLength(a.AddressSize() * 8) b.SetLength(a.AddressSize() * 8) if netaddr.LessThan(b, a) { a, b = b, a } cur := newCursor(root) match = cur.search(a) if match { // Match on beginning of range return } if cur.eof() { // Beginning of range is right of prefixes return } // Current node is not a member of the set for !cur.eof() && !cur.curNode().hasCut(a.AddressSize()) { _ = cur.stepMember() } for _, cna := range cur.netAddrs() { if cna.AddressSize() == a.AddressSize() && netaddr.LessThan(cna, b) || cna.Overlaps(b) { match = true } } return }