123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559 |
- /*
- * This Source Code Form is subject to the terms of the Mozilla Public
- * License, v. 2.0. If a copy of the MPL was not distributed with this
- * file, You can obtain one at https://mozilla.org/MPL/2.0/.
- */
- package otrie
- import (
- "fmt"
- "sort"
- "strings"
- "idio.link/go/netaddr/v2"
- )
/*
onesToMask maps a count of leading one bits (0 through 8) to the octet
mask that has exactly that many high-order bits set.
*/
var onesToMask = [9]byte{
	0x00, 0x80, 0xc0, 0xe0, 0xf0, 0xf8, 0xfc, 0xfe, 0xff,
}
- func newCursor(start *oTrieNode) *cursor {
- initCap := 8 // Not too big, not too small?
- cur := &cursor{
- chain: make([]*oTrieNode, 1, initCap),
- coord: make([]int, 1, initCap),
- }
- cur.chain[0] = start
- return cur
- }
- func maskToOnes(m byte) int {
- for i, v := range onesToMask {
- if v == m {
- return i
- }
- }
- panic("BUG: otrie: maskToOnes()")
- }
- type cursor struct {
- chain []*oTrieNode
- coord []int
- }
- func (c *cursor) eof() bool {
- return c.depth() == 0 &&
- c.coord[0] == len(c.chain[0].next)
- }
- func (c *cursor) String() string {
- l := 0
- p := make([]string, 0, 4)
- for _, n := range c.chain[1:] {
- p = append(p, fmt.Sprintf("%d", n.octet))
- l += maskToOnes(n.mask)
- }
- return strings.Join(p, ".") + fmt.Sprintf("/%d", l)
- }
- func (c *cursor) netAddrs() []*netaddr.NetAddr {
- nas := make([]*netaddr.NetAddr, 0, 2)
- l := 0
- p := make([]byte, 0, 4)
- for _, n := range c.chain[1:] {
- p = append(p, n.octet)
- l += maskToOnes(n.mask)
- }
- for _, cut := range c.chain[c.depth()].cuts {
- tmp := make([]byte, len(p))
- copy(tmp, p)
- for i := 0; i < (cut - c.depth()); i++ {
- tmp = append(tmp, byte(0))
- }
- na := netaddr.NewNetAddr(tmp, l)
- nas = append(nas, na)
- }
- return nas
- }
- func (c *cursor) depth() int {
- return len(c.chain) - 1
- }
- func (c *cursor) descend() int {
- depth := c.depth()
- nextNodes := c.chain[depth].next
- if len(nextNodes) != 0 {
- nextNode := nextNodes[c.coord[depth]]
- c.coord = append(c.coord, 0)
- c.chain = append(c.chain, nextNode)
- }
- return c.depth()
- }
- func (c *cursor) ascend() int {
- depth := c.depth()
- c.coord[depth] = 0
- if depth == 0 {
- return depth
- }
- c.chain[depth] = nil
- c.coord = c.coord[0:depth]
- c.chain = c.chain[0:depth]
- c.coord[depth-1]++
- return c.depth()
- }
- func (c *cursor) curNode() *oTrieNode {
- return c.chain[c.depth()]
- }
- func (c *cursor) step() int {
- if c.coord[c.depth()] == len(c.curNode().next) {
- return c.ascend()
- }
- return c.descend()
- }
- func (c *cursor) stepMember() (depth int) {
- depth = c.depth()
- for !c.eof() && len(c.curNode().cuts) == 0 {
- depth = c.step()
- }
- return
- }
/*
contains reports whether the octet prefix o1/m1 contains o2/m2: m1 must
not be longer than m2, and both octets must agree on every bit
selected by m1.
*/
func contains(o1, m1, o2, m2 byte) bool {
	if m1 > m2 {
		return false
	}
	return (o1^o2)&m1 == 0
}
- func overlaps(o1, m1, o2, m2 byte) bool {
- return contains(o1, m1, o2, m2) ||
- contains(o2, m2, o1, m1)
- }
- func (c *cursor) reset() {
- for i := 1; i < len(c.chain); i++ {
- c.chain = nil
- c.coord[i] = 0
- }
- c.coord[0] = 0
- }
- /* Breadth-first search of OTrie */
- func (c *cursor) search(
- p *netaddr.NetAddr,
- equiv ...func(byte, byte, byte, byte) bool,
- ) (match bool) {
- ones := p.Length()
- for !c.eof() {
- if c.curNode().hasCut(p.AddressSize()) {
- match = true
- return
- }
- if len(c.curNode().next) == 0 {
- return
- }
- ones -= min(ones, maskToOnes(c.curNode().mask))
- octet := p.Address()[c.depth()]
- mask := onesToMask[byte(min(ones, 8))]
- equivTest := overlaps
- if len(equiv) > 0 {
- equivTest = equiv[0]
- }
- matchTest := func(i int) bool {
- next := c.curNode().next[i]
- return !prefixLessThan(
- next.octet, next.mask, octet, mask,
- ) || equivTest(
- next.octet, next.mask, octet, mask,
- )
- }
- n := len(c.curNode().next)
- c.coord[c.depth()] = sort.Search(n, matchTest)
- if c.coord[c.depth()] == n {
- return
- }
- next := c.curNode().next[c.coord[c.depth()]]
- if !equivTest(next.octet, next.mask, octet, mask) {
- break
- }
- _ = c.descend()
- }
- return
- }
- /* Depth-first iteration over OTrie */
- func (c *cursor) iter() <-chan []*oTrieNode {
- ch := make(chan []*oTrieNode)
- go func() {
- defer close(ch)
- for !c.eof() {
- if c.coord[c.depth()] == len(c.curNode().next) {
- if len(c.curNode().cuts) != 0 {
- tmp := make([]*oTrieNode, c.depth())
- copy(tmp, c.chain[1:])
- ch <- tmp
- }
- }
- _ = c.step()
- }
- c.reset()
- }()
- return ch
- }
- type OTrie interface {
- AddPrefix(*netaddr.NetAddr) error
- Contains(*netaddr.NetAddr) bool
- Overlaps(*netaddr.NetAddr) bool
- OverlapsRange(*netaddr.NetAddr, *netaddr.NetAddr) bool
- Print() string
- }
- func NewOTrie() OTrie {
- return &oTrieNode{}
- }
- func newOTrieNode(octet, mask byte) *oTrieNode {
- return &oTrieNode{
- octet: octet,
- mask: mask,
- }
- }
/*
min returns the smaller of a and b.
*/
func min(a, b int) int {
	if b <= a {
		return b
	}
	return a
}
/*
prefixLessThan tests the octet prefix o1/m1 against o2/m2 under the
sheared isometry on byte prefixes described in the body.
*/
func prefixLessThan(o1, m1, o2, m2 byte) bool {
	/*
		We seek an isometry from a binary trie embedded in R²
		to the integral number line, with ones bits branching to
		the right and zeros bits branching to the left. When
		considering only the ones bits of an n-bit binary
		number, the usual decimal encoding suffices, as each ith
		bit, for 0 <= i <= n-1, shifts a point at the origin to
		the right by 2^i. However, the zeros bits have no
		effect, in this sense, as leading zeros do not
		contribute to the integer value, resulting in many
		binary strings occupying the same point. This mapping is
		well-formed and surjective (every integer has a binary
		encoding), but it is not injective.

			R 0 00 000 0000
			           0001
			       001 0010
			           0011
			    01 010 0100
			           0101
			       011 0110
			           0111
			  1 10 100 1000
			           1001
			       101 1010
			           1011
			    11 110 1100
			           1101
			       111 1110
			           1111

		Instead, by allowing each zero bit to provide a negative
		contribution, we ensure that every path (binary string)
		maps to a unique point on the integral line. To see that
		this is an isometry, project the nodes of a binary trie
		embedded in R², with branch width w = 2^(n-j-1) at rank
		j and arbitrary branch height h, onto the integral line.

		Here, we express the above isometry by simply
		subtracting the contributions of the place values of the
		zeros bits, which we obtain by ones complement (XOR).
		Masking ensures we map only the prefix.

		Finally, we note that projecting the usual geometry of a
		binary trie onto the integral line places each ancestor
		at the midpoint of its descendants. This creates a
		disagreement between the partial and total orders, which
		makes it difficult to recover the original partial order
		generated by the trie.

			                0000
			            000
			                0001
			        00
			                0010
			            001
			                0011
			    0
			                0100
			            010
			                0101
			        01
			                0110
			            011
			                0111
			0/0
			                1000
			            100
			                1001
			        10
			                1010
			            101
			                1011
			    1
			                1100
			            110
			                1101
			        11
			                1110
			            111
			                1111

		Instead, we require that the total order < given by the
		proposed isometry have the property that it always agree
		with the partial order ≪ given by the binary trie. That
		is, given nodes a, b of a binary trie, if a ≪ b, then
		a < b. Note that the converse does not hold, nor would
		it be useful if it did.

		To accomplish this, we effectively apply a horizontal
		shear to the right on the binary trie such that the root
		of each subtree falls to the right of its descendants.

			                0000
			                0001
			            000
			                0010
			                0011
			            001
			        00
			                0100
			                0101
			            010
			                0110
			                0111
			            011
			        01
			    0
			                1000
			                1001
			            100
			                1010
			                1011
			            101
			        10
			                1100
			                1101
			            110
			                1110
			                1111
			            111
			        11
			    1
			0/0

		In the unsheared isometry, the nodes we wish to reorder
		are specifically those which precede their descendants--
		that is, whenever a < b and b ≪ a, we would like to
		move a to the right of b so that b < a, thereby agreeing
		with the partial order. Since b ≪ a, it is the case
		that a represents a shorter prefix than b. Furthermore,
		since b is a descendant of a, b's prefix is guaranteed
		to share all of the bits of a. Conversely, if a and b
		share a prefix, then that prefix represents the most
		recent shared ancestor (least upper bound). And if the
		prefix shared by a and b is a's prefix, then it must be
		the case that a's prefix is shorter, and b is a
		descendant of a. Thus, b ≪ a.

		Therefore, we conclude that a must be reordered
		precisely when a < b and b shares a's prefix. Hence our
		assertion `f1 < f2 && o1 != o2&m1`. To reorder, we
		simply place prefixes with longer masks before prefixes
		with shorter ones, as given by `m1 > m2`.
	*/
	// Unsheared positions: ones bits add their place value, zeros
	// bits subtract it, and only bits under the mask take part.
	f1 := int(o1&m1) - int(^o1&m1)
	f2 := int(o2&m2) - int(^o2&m2)
	if f1 >= f2 {
		return false
	}
	return m1 > m2 || o1 != o2&m1
}
/*
Every octet trie node represents the end of a prefix, but not every
represented prefix is meant to be a member of the set. For example,
when a /24 is added, antecedent nodes representing a /8 and a /16 are
created as well:

	/8 -> /16 -> /24

The `cuts` field tells which nodes represent members of the set and
which do not. A node with an empty `cuts` does not represent a member.
A node with one cut represents one member whose total size in octets
is given by the cut. A node with cuts {4, 128} simultaneously
represents a 4-byte member and a 128-byte member, both of which share
the same prefix octets and prefix length.
*/
type oTrieNode struct {
	// octet holds this level's prefix bits, already masked by mask.
	octet byte
	// mask selects the significant high-order bits of octet.
	mask byte
	// next holds the children, kept sorted by prefixLessThan.
	next []*oTrieNode
	// cuts lists the address sizes, in octets, of the members that
	// end at this node.
	cuts []int
}
- func (n *oTrieNode) hasCut(cut int) bool {
- for _, c := range n.cuts {
- if c == cut {
- return true
- }
- }
- return false
- }
- func (n *oTrieNode) setCut(cut int) {
- if n.cuts == nil {
- n.cuts = make([]int, 0, 2)
- }
- if n.hasCut(cut) {
- return
- }
- n.cuts = append(n.cuts, cut)
- }
- func (root *oTrieNode) Print() (output string) {
- maskToOnes := func(m byte) int {
- for i, v := range onesToMask {
- if v == m {
- return i
- }
- }
- panic("BUG: *oTrieNode.Print(): maskToOnes()")
- }
- cur := newCursor(root)
- for chain := range cur.iter() {
- l := 0
- p := make([]string, 0, len(chain))
- for _, o := range chain {
- p = append(p, fmt.Sprintf("%d", o.octet))
- l += maskToOnes(o.mask)
- }
- output += strings.Join(p, ".")
- for _, cut := range chain[len(chain)-1].cuts {
- for i := 0; i < (cut - len(chain)); i++ {
- output += ".0"
- }
- }
- output += fmt.Sprintf("/%d\n", l)
- }
- return
- }
- func (root *oTrieNode) String() string {
- return root.Print()
- }
- func (root *oTrieNode) AddPrefix(p *netaddr.NetAddr) error {
- nodeLessThan :=
- func(cur *oTrieNode) func(int, int) bool {
- return func(i, j int) bool {
- return prefixLessThan(
- cur.next[i].octet, cur.next[i].mask,
- cur.next[j].octet, cur.next[j].mask,
- )
- }
- }
- curNode := root
- ones := p.Length()
- for _, o := range p.Address() {
- if ones == 0 {
- break
- }
- mask := onesToMask[min(8, ones)]
- ones -= min(8, ones)
- octet := o & mask
- if curNode.next == nil {
- curNode.next = make([]*oTrieNode, 0, 64)
- }
- var nextNode *oTrieNode
- oIdx := sort.Search(len(curNode.next), func(i int) bool {
- return !prefixLessThan(
- curNode.next[i].octet,
- curNode.next[i].mask,
- octet,
- mask,
- )
- })
- switch {
- case oIdx != len(curNode.next) &&
- curNode.next[oIdx].octet == octet &&
- curNode.next[oIdx].mask == mask:
- nextNode = curNode.next[oIdx]
- if ones == 0 {
- nextNode.setCut(p.AddressSize())
- }
- default:
- nextNode = newOTrieNode(octet, mask)
- if ones == 0 {
- nextNode.setCut(p.AddressSize())
- }
- curNode.next = append(curNode.next, nextNode)
- if oIdx != len(curNode.next) {
- sort.Slice(curNode.next, nodeLessThan(curNode))
- }
- }
- curNode = nextNode
- }
- return nil
- }
- /* Test for overlap; after all, even if a given prefix does
- * not contain a certain range, if the prefix is a subset of
- * that range, then the range is relevant.
- */
- func (root *oTrieNode) Overlaps(p *netaddr.NetAddr) bool {
- return newCursor(root).search(p)
- }
- func (root *oTrieNode) Contains(p *netaddr.NetAddr) bool {
- return newCursor(root).search(p, contains)
- }
- /* Test address range ([a, b]) for overlap, ignoring address
- * lengths.
- */
- func (root *oTrieNode) OverlapsRange(
- a *netaddr.NetAddr,
- b *netaddr.NetAddr,
- ) (match bool) {
- if a.AddressSize() != b.AddressSize() {
- return
- }
- a.SetLength(a.AddressSize() * 8)
- b.SetLength(a.AddressSize() * 8)
- if netaddr.LessThan(b, a) {
- a, b = b, a
- }
- cur := newCursor(root)
- match = cur.search(a)
- if match { // Match on beginning of range
- return
- }
- if cur.eof() { // Beginning of range is right of prefixes
- return
- }
- // Current node is not a member of the set
- for !cur.eof() && !cur.curNode().hasCut(a.AddressSize()) {
- _ = cur.stepMember()
- }
- for _, cna := range cur.netAddrs() {
- if cna.AddressSize() == a.AddressSize() &&
- netaddr.LessThan(cna, b) ||
- cna.Overlaps(b) {
- match = true
- }
- }
- return
- }
|