netaddr.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. /*
  2. * This Source Code Form is subject to the terms of the Mozilla Public
  3. * License, v. 2.0. If a copy of the MPL was not distributed with this
  4. * file, You can obtain one at https://mozilla.org/MPL/2.0/.
  5. */
  6. package netaddr
  7. import (
  8. "bytes"
  9. "fmt"
  10. "io"
  11. "net"
  12. "strconv"
  13. )
  // NetAddr is an IP address together with a prefix length.
//
// Values produced by IP store IPv4 addresses in 4 bytes and IPv6 addresses
// in 16 bytes. The address may still carry host bits beyond the prefix; use
// FirstAddress to obtain the network address.
type NetAddr struct {
	// Address holds the raw address bytes, most significant byte first.
	Address []byte
	// Length is the prefix length in bits (Mask turns it into Length one-bits).
	Length byte
}
  // applyMask returns address ANDed with mask, delegating to net.IP.Mask.
// When the two lengths are incompatible net.IP.Mask yields nil, so callers
// should pass a mask of the same length as the address.
func applyMask(address, mask []byte) []byte {
	return net.IP(address).Mask(net.IPMask(mask))
}
  22. func (na *NetAddr) AddressSize() int {
  23. return len(na.Address)
  24. }
  25. func (na *NetAddr) Mask() []byte {
  26. onesToMask := [9]byte{0, 128, 192, 224, 240, 248, 252, 254, 255}
  27. min := func(a, b byte) byte {
  28. if a < b {
  29. return a
  30. }
  31. return b
  32. }
  33. ones := na.Length
  34. mask := make([]byte, 0, na.AddressSize())
  35. for i := 0; i < na.AddressSize(); i++ {
  36. olen := min(8, ones)
  37. ones -= olen
  38. mask = append(mask, onesToMask[int(olen)])
  39. }
  40. return mask
  41. }
  42. func (na *NetAddr) FirstAddress() *NetAddr {
  43. bitSize := na.AddressSize() * 8
  44. length := na.Length
  45. mask := net.CIDRMask(int(length), bitSize)
  46. addressableSlice := na.Address
  47. first0 := applyMask(addressableSlice, mask)
  48. first := make([]byte, na.AddressSize())
  49. copy(first[:], first0)
  50. return &NetAddr{Address: first, Length: length}
  51. }
  52. func (na *NetAddr) LastAddress() *NetAddr {
  53. bitSize := na.AddressSize() * 8
  54. length := na.Length
  55. mask := net.CIDRMask(int(length), bitSize)
  56. allOnes := net.CIDRMask(bitSize, bitSize)
  57. invMask := byteXor(mask, allOnes)
  58. lastAddr := byteOr(na.FirstAddress().Address, invMask)
  59. return &NetAddr{Address: lastAddr, Length: length}
  60. }
  61. func (na *NetAddr) String() string {
  62. addr := net.IP(na.Address).String()
  63. length := fmt.Sprintf("%d", na.Length)
  64. return addr + "/" + length
  65. }
  66. func (na *NetAddr) IsRFC1918() bool {
  67. return IsSubset(na, &NetAddr{
  68. Address: []byte{10, 0, 0, 0},
  69. Length: byte(8),
  70. }) || IsSubset(na, &NetAddr{
  71. Address: []byte{172, 16, 0, 0},
  72. Length: byte(12),
  73. }) || IsSubset(na, &NetAddr{
  74. Address: []byte{192, 168, 0, 0},
  75. Length: byte(16),
  76. })
  77. }
  // byteOp combines u and v element-wise with f and returns the results in a
// new slice of len(u). v must be at least as long as u; otherwise the
// indexing panics.
func byteOp(u, v []byte, f func(byte, byte) byte) []byte {
	out := make([]byte, len(u))
	for i, x := range u {
		out[i] = f(x, v[i])
	}
	return out
}
  85. func byteOr(u, v []byte) []byte {
  86. return byteOp(u, v, func(ui, vi byte) byte {
  87. return ui | vi
  88. })
  89. }
  90. func byteXor(u, v []byte) []byte {
  91. return byteOp(u, v, func(ui, vi byte) byte {
  92. return ui ^ vi
  93. })
  94. }
  95. func SetLength(na *NetAddr, length byte) *NetAddr {
  96. min := func(a, b byte) byte {
  97. if a < b {
  98. return a
  99. }
  100. return b
  101. }
  102. return &NetAddr{
  103. Address: na.Address,
  104. Length: min(length, byte(na.AddressSize()*8)),
  105. }
  106. }
  // octetLessThan compares one octet of two prefixes: octet o1 under mask m1
// against octet o2 under mask m2. The original note calls this a test
// against a "skewed total order on byte-size prefixes".
//
// Each (octet, mask) pair is mapped to a key f = (o&m) - (^o&m). Masked bits
// that are set count positively and masked bits that are clear count
// negatively, which equals 2*(o&m) - m. With m == 0 the key is 0; with
// m == 0xff it is 2*o - 255. Adding one more mask bit b moves the key by -b
// (bit clear) or +b (bit set), so a prefix's key lies between the keys of
// its lower and upper halves.
//
// The result is true only if o1's key is smaller AND it is not the case that
// o2 agrees with o1 on all of m1's bits while m1 <= m2.
// NOTE(review): presumably this stops a prefix from sorting before a prefix
// it contains. It relies on m1/m2 being contiguous high-bit masks (as Mask
// produces), so that numeric m1 <= m2 means "m1 is no longer than m2".
// Confirm.
func octetLessThan(o1, m1, o2, m2 byte) bool {
	f1 := int(o1&m1) - int((o1^byte(255))&m1)
	f2 := int(o2&m2) - int((o2^byte(255))&m2)
	return f1 < f2 && !(o1&m1 == o2&m1 && m1 <= m2)
}
  113. func LessThan(na1 *NetAddr, na2 *NetAddr) (less bool) {
  114. if na1.AddressSize() != na2.AddressSize() {
  115. return
  116. }
  117. mask1 := na1.Mask()
  118. mask2 := na2.Mask()
  119. for i := 0; i < na1.AddressSize(); i++ {
  120. if octetLessThan(
  121. na1.Address[i], mask1[i],
  122. na2.Address[i], mask2[i],
  123. ) {
  124. less = true
  125. break
  126. }
  127. if na1.Address[i] == na2.Address[i] &&
  128. mask1[i] == mask2[i] {
  129. continue
  130. }
  131. break
  132. }
  133. return
  134. }
  135. func IsEqual(na1, na2 *NetAddr) bool {
  136. return na1.Length == na2.Length &&
  137. bytes.Equal(na1.Address, na2.Address)
  138. }
  139. func Contains(na1, na2 *NetAddr) bool {
  140. tmp := SetLength(na2, na1.Length)
  141. return na1.Length <= na2.Length &&
  142. IsEqual(na1.FirstAddress(), tmp.FirstAddress())
  143. }
  144. func IsSubset(na1, na2 *NetAddr) bool {
  145. tmp := SetLength(na1, na2.Length)
  146. return na2.Length <= na1.Length &&
  147. IsEqual(na2.FirstAddress(), tmp.FirstAddress())
  148. }
  149. func Overlaps(na1, na2 *NetAddr) bool {
  150. return Contains(na1, na2) || Contains(na2, na1)
  151. }
  152. func parseTruncatedIPv4(s string) (na *NetAddr, err error) {
  153. na = &NetAddr{Address: make([]byte, 4)}
  154. r := bytes.NewBufferString(s)
  155. buf := new(bytes.Buffer)
  156. i := 0
  157. onLen := false
  158. for i < 4 {
  159. var b byte
  160. b, err = r.ReadByte()
  161. switch {
  162. case err == io.EOF:
  163. na.Length = byte(8*i - 8)
  164. if onLen {
  165. var o int
  166. o, err = strconv.Atoi(buf.String())
  167. if err != nil {
  168. return
  169. }
  170. if o < 0 || 32 < o {
  171. err = fmt.Errorf("length out of range: %v", o)
  172. return
  173. }
  174. na.Length = byte(o)
  175. return
  176. }
  177. err = nil
  178. return
  179. case err != nil:
  180. return
  181. case '0' <= b && b <= '9':
  182. buf.WriteByte(b)
  183. continue
  184. case b == '.' || b == '/':
  185. var o int
  186. o, err = strconv.Atoi(buf.String())
  187. if err != nil {
  188. return
  189. }
  190. if o < 0 || 255 < o {
  191. err = fmt.Errorf("octet out of range: %v", o)
  192. return
  193. }
  194. na.Address[i] = byte(o)
  195. i++
  196. buf = new(bytes.Buffer)
  197. if b == '/' {
  198. onLen = true
  199. }
  200. }
  201. }
  202. err = fmt.Errorf("input is longer than expected")
  203. return
  204. }
  205. func IP(s string) (na *NetAddr, err error) {
  206. var ip net.IP
  207. var ipNet *net.IPNet
  208. ip, ipNet, err = net.ParseCIDR(s)
  209. switch err.(type) {
  210. case nil:
  211. case *net.ParseError:
  212. ip = net.ParseIP(s)
  213. if ip == nil {
  214. na, err = parseTruncatedIPv4(s)
  215. if err != nil {
  216. err = fmt.Errorf("invalid ip address %v: %v", s, err)
  217. }
  218. return
  219. }
  220. if test := ip.To4(); test != nil {
  221. ipNet = &net.IPNet{
  222. IP: ip,
  223. Mask: net.IPMask([]byte{255, 255, 255, 255}),
  224. }
  225. } else if test := ip.To16(); test != nil {
  226. ipNet = &net.IPNet{
  227. IP: ip,
  228. Mask: net.IPMask(
  229. []byte{
  230. 0xff, 0xff, 0xff, 0xff,
  231. 0xff, 0xff, 0xff, 0xff,
  232. 0xff, 0xff, 0xff, 0xff,
  233. 0xff, 0xff, 0xff, 0xff,
  234. },
  235. ),
  236. }
  237. }
  238. err = nil
  239. default:
  240. return
  241. }
  242. var length, _ int = net.IPMask.Size(ipNet.Mask)
  243. var address []byte = ip.To4()
  244. if address == nil {
  245. address = ip.To16()
  246. }
  247. if address == nil {
  248. panic(fmt.Sprintf("unexpected failure parsing '%s'\n", s))
  249. }
  250. na = &NetAddr{Address: address, Length: byte(length)}
  251. return
  252. }