otrie.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559
  1. /*
  2. * This Source Code Form is subject to the terms of the Mozilla Public
  3. * License, v. 2.0. If a copy of the MPL was not distributed with this
  4. * file, You can obtain one at https://mozilla.org/MPL/2.0/.
  5. */
  6. package otrie
  7. import (
  8. "fmt"
  9. "sort"
  10. "strings"
  11. "idio.link/go/netaddr/v2"
  12. )
  /*
   * onesToMask maps a count of leading one bits (0 through 8) to the
   * octet mask that has exactly that many high-order bits set.
   */
  var onesToMask = [9]byte{
  	0x00, 0x80, 0xC0, 0xE0, 0xF0, 0xF8, 0xFC, 0xFE, 0xFF,
  }
  14. func newCursor(start *oTrieNode) *cursor {
  15. initCap := 8 // Not too big, not too small?
  16. cur := &cursor{
  17. chain: make([]*oTrieNode, 1, initCap),
  18. coord: make([]int, 1, initCap),
  19. }
  20. cur.chain[0] = start
  21. return cur
  22. }
  /*
   * maskToOnes returns the number of leading one bits in the octet mask
   * m. Only the nine contiguous masks 0, 128, 192, ..., 255 are valid;
   * any other value indicates a bug and panics.
   */
  func maskToOnes(m byte) int {
  	switch m {
  	case 0x00:
  		return 0
  	case 0x80:
  		return 1
  	case 0xC0:
  		return 2
  	case 0xE0:
  		return 3
  	case 0xF0:
  		return 4
  	case 0xF8:
  		return 5
  	case 0xFC:
  		return 6
  	case 0xFE:
  		return 7
  	case 0xFF:
  		return 8
  	}
  	panic("BUG: otrie: maskToOnes()")
  }
  31. type cursor struct {
  32. chain []*oTrieNode
  33. coord []int
  34. }
  35. func (c *cursor) eof() bool {
  36. return c.depth() == 0 &&
  37. c.coord[0] == len(c.chain[0].next)
  38. }
  39. func (c *cursor) String() string {
  40. l := 0
  41. p := make([]string, 0, 4)
  42. for _, n := range c.chain[1:] {
  43. p = append(p, fmt.Sprintf("%d", n.octet))
  44. l += maskToOnes(n.mask)
  45. }
  46. return strings.Join(p, ".") + fmt.Sprintf("/%d", l)
  47. }
  48. func (c *cursor) netAddrs() []*netaddr.NetAddr {
  49. nas := make([]*netaddr.NetAddr, 0, 2)
  50. l := 0
  51. p := make([]byte, 0, 4)
  52. for _, n := range c.chain[1:] {
  53. p = append(p, n.octet)
  54. l += maskToOnes(n.mask)
  55. }
  56. for _, cut := range c.chain[c.depth()].cuts {
  57. tmp := make([]byte, len(p))
  58. copy(tmp, p)
  59. for i := 0; i < (cut - c.depth()); i++ {
  60. tmp = append(tmp, byte(0))
  61. }
  62. na := netaddr.NewNetAddr(tmp, l)
  63. nas = append(nas, na)
  64. }
  65. return nas
  66. }
  67. func (c *cursor) depth() int {
  68. return len(c.chain) - 1
  69. }
  70. func (c *cursor) descend() int {
  71. depth := c.depth()
  72. nextNodes := c.chain[depth].next
  73. if len(nextNodes) != 0 {
  74. nextNode := nextNodes[c.coord[depth]]
  75. c.coord = append(c.coord, 0)
  76. c.chain = append(c.chain, nextNode)
  77. }
  78. return c.depth()
  79. }
  80. func (c *cursor) ascend() int {
  81. depth := c.depth()
  82. c.coord[depth] = 0
  83. if depth == 0 {
  84. return depth
  85. }
  86. c.chain[depth] = nil
  87. c.coord = c.coord[0:depth]
  88. c.chain = c.chain[0:depth]
  89. c.coord[depth-1]++
  90. return c.depth()
  91. }
  92. func (c *cursor) curNode() *oTrieNode {
  93. return c.chain[c.depth()]
  94. }
  95. func (c *cursor) step() int {
  96. if c.coord[c.depth()] == len(c.curNode().next) {
  97. return c.ascend()
  98. }
  99. return c.descend()
  100. }
  101. func (c *cursor) stepMember() (depth int) {
  102. depth = c.depth()
  103. for !c.eof() && len(c.curNode().cuts) == 0 {
  104. depth = c.step()
  105. }
  106. return
  107. }
  /*
   * contains reports whether octet prefix o1/m1 contains o2/m2: m1 must
   * not exceed m2, and the two octets must agree on every bit selected
   * by m1.
   */
  func contains(o1, m1, o2, m2 byte) bool {
  	if m1 > m2 {
  		return false
  	}
  	return (o1^o2)&m1 == 0
  }
  /*
   * overlaps reports whether either of the octet prefixes o1/m1 and
   * o2/m2 contains the other. That is the case exactly when the octets
   * agree on every bit selected by the numerically smaller mask.
   */
  func overlaps(o1, m1, o2, m2 byte) bool {
  	shorter := m1
  	if m2 < m1 {
  		shorter = m2
  	}
  	return (o1^o2)&shorter == 0
  }
  115. func (c *cursor) reset() {
  116. for i := 1; i < len(c.chain); i++ {
  117. c.chain = nil
  118. c.coord[i] = 0
  119. }
  120. c.coord[0] = 0
  121. }
  122. /* Breadth-first search of OTrie */
  123. func (c *cursor) search(
  124. p *netaddr.NetAddr,
  125. equiv ...func(byte, byte, byte, byte) bool,
  126. ) (match bool) {
  127. ones := p.Length()
  128. for !c.eof() {
  129. if c.curNode().hasCut(p.AddressSize()) {
  130. match = true
  131. return
  132. }
  133. if len(c.curNode().next) == 0 {
  134. return
  135. }
  136. ones -= min(ones, maskToOnes(c.curNode().mask))
  137. octet := p.Address()[c.depth()]
  138. mask := onesToMask[byte(min(ones, 8))]
  139. equivTest := overlaps
  140. if len(equiv) > 0 {
  141. equivTest = equiv[0]
  142. }
  143. matchTest := func(i int) bool {
  144. next := c.curNode().next[i]
  145. return !prefixLessThan(
  146. next.octet, next.mask, octet, mask,
  147. ) || equivTest(
  148. next.octet, next.mask, octet, mask,
  149. )
  150. }
  151. n := len(c.curNode().next)
  152. c.coord[c.depth()] = sort.Search(n, matchTest)
  153. if c.coord[c.depth()] == n {
  154. return
  155. }
  156. next := c.curNode().next[c.coord[c.depth()]]
  157. if !equivTest(next.octet, next.mask, octet, mask) {
  158. break
  159. }
  160. _ = c.descend()
  161. }
  162. return
  163. }
  164. /* Depth-first iteration over OTrie */
  165. func (c *cursor) iter() <-chan []*oTrieNode {
  166. ch := make(chan []*oTrieNode)
  167. go func() {
  168. defer close(ch)
  169. for !c.eof() {
  170. if c.coord[c.depth()] == len(c.curNode().next) {
  171. if len(c.curNode().cuts) != 0 {
  172. tmp := make([]*oTrieNode, c.depth())
  173. copy(tmp, c.chain[1:])
  174. ch <- tmp
  175. }
  176. }
  177. _ = c.step()
  178. }
  179. c.reset()
  180. }()
  181. return ch
  182. }
  183. type OTrie interface {
  184. AddPrefix(*netaddr.NetAddr) error
  185. Contains(*netaddr.NetAddr) bool
  186. Overlaps(*netaddr.NetAddr) bool
  187. OverlapsRange(*netaddr.NetAddr, *netaddr.NetAddr) bool
  188. Print() string
  189. }
  190. func NewOTrie() OTrie {
  191. return &oTrieNode{}
  192. }
  193. func newOTrieNode(octet, mask byte) *oTrieNode {
  194. return &oTrieNode{
  195. octet: octet,
  196. mask: mask,
  197. }
  198. }
  /*
   * min returns the smaller of a and b.
   */
  func min(a, b int) int {
  	if b <= a {
  		return b
  	}
  	return a
  }
  /*
   * prefixLessThan tests octet prefix o1/m1 against o2/m2 under the
   * sheared isometry on byte prefixes described inside the function.
   *
   * Each prefix is folded to an integer: every masked one bit adds its
   * place value and every masked zero bit subtracts it. The result is
   * true when the first fold is smaller and, in addition, either o1
   * differs from o2 restricted to m1 or m1 is the larger mask.
   *
   * Note that for an ancestor and a descendant on its ones branch (for
   * example 0/1 and 01/2, with octets already masked) the test is false
   * in both directions.
   */
  func prefixLessThan(o1, m1, o2, m2 byte) bool {
  	/*
  	  We seek an isometry from a binary trie embedded in R² to the
  	  integral number line, with ones bits branching to the right and
  	  zeros bits branching to the left. When considering only the ones
  	  bits of an n-bit binary number, the usual decimal encoding
  	  suffices, as each ith bit, for 0 <= i <= n-1, shifts a point at
  	  the origin to the right by 2^i. However, the zeros bits have no
  	  effect, in this sense, as leading zeros do not contribute to the
  	  integer value, resulting in many binary strings occupying the
  	  same point. This mapping is well-formed and surjective (every
  	  integer has a binary encoding), but it is not injective.

  	    R -> 0 -> 00 -> 000 -> 0000
  	                        -> 0001
  	                 -> 001 -> 0010
  	                        -> 0011
  	           -> 01 -> 010 -> 0100
  	                        -> 0101
  	                 -> 011 -> 0110
  	                        -> 0111
  	      -> 1 -> 10 -> 100 -> 1000
  	                        -> 1001
  	                 -> 101 -> 1010
  	                        -> 1011
  	           -> 11 -> 110 -> 1100
  	                        -> 1101
  	                 -> 111 -> 1110
  	                        -> 1111

  	  Instead, by allowing each zero bit to provide a negative
  	  contribution, we ensure that every path (binary string) maps to
  	  a unique point on the integral line. To see that this is an
  	  isometry, project the nodes of a binary trie embedded in R²,
  	  with branch width w = 2^(n-j-1) at rank j and arbitrary branch
  	  height h, onto the integral line.

  	  Here, we express the above isometry by simply subtracting the
  	  contributions of the place values of the zeros bits, which we
  	  obtain by ones complement (XOR). Masking ensures we map only the
  	  prefix.

  	  Finally, we note that projecting the usual geometry of a binary
  	  trie onto the integral line places each ancestor at the midpoint
  	  of its descendants. This creates a disagreement between the
  	  partial and total orders, which makes it difficult to recover
  	  the original partial order generated by the trie.

  	                    -> 0000
  	                -> 000
  	                    -> 0001
  	            -> 00
  	                    -> 0010
  	                -> 001
  	                    -> 0011
  	        -> 0
  	                    -> 0100
  	                -> 010
  	                    -> 0101
  	            -> 01
  	                    -> 0110
  	                -> 011
  	                    -> 0111
  	    0/0
  	                    -> 1000
  	                -> 100
  	                    -> 1001
  	            -> 10
  	                    -> 1010
  	                -> 101
  	                    -> 1011
  	        -> 1
  	                    -> 1100
  	                -> 110
  	                    -> 1101
  	            -> 11
  	                    -> 1110
  	                -> 111
  	                    -> 1111

  	  Instead, we require that the total order < given by the proposed
  	  isometry have the property that it always agree with the partial
  	  order ≪ given by the binary trie. That is, given nodes a, b of a
  	  binary trie, if a ≪ b, then a < b. Note that the converse does
  	  not hold, nor would it be useful if it did.

  	  To accomplish this, we effectively apply a horizontal shear to
  	  the right on the binary trie such that the root of each subtree
  	  falls to the right of its descendants.

  	                    -> 0000
  	                    -> 0001
  	                -> 000
  	                    -> 0010
  	                    -> 0011
  	                -> 001
  	            -> 00
  	                    -> 0100
  	                    -> 0101
  	                -> 010
  	                    -> 0110
  	                    -> 0111
  	                -> 011
  	            -> 01
  	        -> 0
  	                    -> 1000
  	                    -> 1001
  	                -> 100
  	                    -> 1010
  	                    -> 1011
  	                -> 101
  	            -> 10
  	                    -> 1100
  	                    -> 1101
  	                -> 110
  	                    -> 1110
  	                    -> 1111
  	                -> 111
  	            -> 11
  	        -> 1
  	    0/0

  	  In the unsheared isometry, the nodes we wish to reorder are
  	  specifically those which precede their descendants -- that is,
  	  whenever a < b and b ≪ a, we would like to move a to the right
  	  of b so that b < a, thereby agreeing with the partial order.
  	  Since b ≪ a, it is the case that a represents a shorter prefix
  	  than b. Furthermore, since b is a descendant of a, b's prefix is
  	  guaranteed to share all of the bits of a. Conversely, if a and b
  	  share a prefix, then that prefix represents the most recent
  	  shared ancestor (least upper bound). And if the prefix shared by
  	  a and b is a's prefix, then it must be the case that a's prefix
  	  is shorter, and b is a descendant of a. Thus, b ≪ a.

  	  Therefore, we conclude that a must be reordered precisely when
  	  a < b and b shares a's prefix. Hence our assertion
  	  `f1 < f2 && o1 != o2&m1`. To reorder, we simply place prefixes
  	  with longer masks before prefixes with shorter ones, as given by
  	  `m1 > m2`.
  	*/
  	fold := func(o, m byte) int {
  		// Ones under the mask count positively, zeros negatively.
  		return int(o&m) - int(^o&m)
  	}
  	if fold(o1, m1) >= fold(o2, m2) {
  		return false
  	}
  	return m1 > m2 || o1 != o2&m1
  }
  346. /*
  347. Every octet trie node represents the end of a prefix, but
  348. not all represented prefixes are intended to be members of
  349. the set. For example, when a /24 is added, antecedent
  350. nodes representing a /8 and /16 are also created:
  351. /8 -> /16 -> /24
  352. The `cuts` field is used to determine which nodes
  353. represent members of the set and which do not. Nodes with
  354. an empty `cuts` do not represent members. A node with one
  355. cut represents one member whose total size in octets is
  356. given by the cut. A node with cuts {4, 128} simultaneously
  357. represents a 4-byte member and a 128-byte member, both of
  358. which share the same prefix octets and prefix length.
  359. */
  type oTrieNode struct {
  	octet byte         // prefix bits of this octet; AddPrefix stores them masked
  	mask  byte         // which bits of octet are significant
  	next  []*oTrieNode // child nodes for the following octet
  	cuts  []int        // sizes in octets of the members that end at this node
  }

  /*
   * hasCut reports whether the node represents a member whose total size
   * in octets is cut.
   */
  func (n *oTrieNode) hasCut(cut int) bool {
  	for i := range n.cuts {
  		if n.cuts[i] == cut {
  			return true
  		}
  	}
  	return false
  }

  /*
   * setCut marks the node as representing a member whose total size in
   * octets is cut. Recording the same cut twice has no further effect.
   */
  func (n *oTrieNode) setCut(cut int) {
  	if n.cuts == nil {
  		n.cuts = make([]int, 0, 2)
  	}
  	if !n.hasCut(cut) {
  		n.cuts = append(n.cuts, cut)
  	}
  }
  383. func (root *oTrieNode) Print() (output string) {
  384. maskToOnes := func(m byte) int {
  385. for i, v := range onesToMask {
  386. if v == m {
  387. return i
  388. }
  389. }
  390. panic("BUG: *oTrieNode.Print(): maskToOnes()")
  391. }
  392. cur := newCursor(root)
  393. for chain := range cur.iter() {
  394. l := 0
  395. p := make([]string, 0, len(chain))
  396. for _, o := range chain {
  397. p = append(p, fmt.Sprintf("%d", o.octet))
  398. l += maskToOnes(o.mask)
  399. }
  400. output += strings.Join(p, ".")
  401. for _, cut := range chain[len(chain)-1].cuts {
  402. for i := 0; i < (cut - len(chain)); i++ {
  403. output += ".0"
  404. }
  405. }
  406. output += fmt.Sprintf("/%d\n", l)
  407. }
  408. return
  409. }
  410. func (root *oTrieNode) String() string {
  411. return root.Print()
  412. }
  413. func (root *oTrieNode) AddPrefix(p *netaddr.NetAddr) error {
  414. nodeLessThan :=
  415. func(cur *oTrieNode) func(int, int) bool {
  416. return func(i, j int) bool {
  417. return prefixLessThan(
  418. cur.next[i].octet, cur.next[i].mask,
  419. cur.next[j].octet, cur.next[j].mask,
  420. )
  421. }
  422. }
  423. curNode := root
  424. ones := p.Length()
  425. for _, o := range p.Address() {
  426. if ones == 0 {
  427. break
  428. }
  429. mask := onesToMask[min(8, ones)]
  430. ones -= min(8, ones)
  431. octet := o & mask
  432. if curNode.next == nil {
  433. curNode.next = make([]*oTrieNode, 0, 64)
  434. }
  435. var nextNode *oTrieNode
  436. oIdx := sort.Search(len(curNode.next), func(i int) bool {
  437. return !prefixLessThan(
  438. curNode.next[i].octet,
  439. curNode.next[i].mask,
  440. octet,
  441. mask,
  442. )
  443. })
  444. switch {
  445. case oIdx != len(curNode.next) &&
  446. curNode.next[oIdx].octet == octet &&
  447. curNode.next[oIdx].mask == mask:
  448. nextNode = curNode.next[oIdx]
  449. if ones == 0 {
  450. nextNode.setCut(p.AddressSize())
  451. }
  452. default:
  453. nextNode = newOTrieNode(octet, mask)
  454. if ones == 0 {
  455. nextNode.setCut(p.AddressSize())
  456. }
  457. curNode.next = append(curNode.next, nextNode)
  458. if oIdx != len(curNode.next) {
  459. sort.Slice(curNode.next, nodeLessThan(curNode))
  460. }
  461. }
  462. curNode = nextNode
  463. }
  464. return nil
  465. }
  466. /* Test for overlap; after all, even if a given prefix does
  467. * not contain a certain range, if the prefix is a subset of
  468. * that range, then the range is relevant.
  469. */
  470. func (root *oTrieNode) Overlaps(p *netaddr.NetAddr) bool {
  471. return newCursor(root).search(p)
  472. }
  473. func (root *oTrieNode) Contains(p *netaddr.NetAddr) bool {
  474. return newCursor(root).search(p, contains)
  475. }
  476. /* Test address range ([a, b]) for overlap, ignoring address
  477. * lengths.
  478. */
  479. func (root *oTrieNode) OverlapsRange(
  480. a *netaddr.NetAddr,
  481. b *netaddr.NetAddr,
  482. ) (match bool) {
  483. if a.AddressSize() != b.AddressSize() {
  484. return
  485. }
  486. a.SetLength(a.AddressSize() * 8)
  487. b.SetLength(a.AddressSize() * 8)
  488. if netaddr.LessThan(b, a) {
  489. a, b = b, a
  490. }
  491. cur := newCursor(root)
  492. match = cur.search(a)
  493. if match { // Match on beginning of range
  494. return
  495. }
  496. if cur.eof() { // Beginning of range is right of prefixes
  497. return
  498. }
  499. // Current node is not a member of the set
  500. for !cur.eof() && !cur.curNode().hasCut(a.AddressSize()) {
  501. _ = cur.stepMember()
  502. }
  503. for _, cna := range cur.netAddrs() {
  504. if cna.AddressSize() == a.AddressSize() &&
  505. netaddr.LessThan(cna, b) ||
  506. cna.Overlaps(b) {
  507. match = true
  508. }
  509. }
  510. return
  511. }