extrapolation.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
  1. // SPDX-License-Identifier: Unlicense OR MIT
  2. package fling
  3. import (
  4. "math"
  5. "strconv"
  6. "strings"
  7. "time"
  8. )
  9. // Extrapolation computes a 1-dimensional velocity estimate
  10. // for a set of timestamped points using the least squares
  11. // fit of a 2nd order polynomial. The same method is used
  12. // by Android.
  13. type Extrapolation struct {
  14. // Index into points.
  15. idx int
  16. // Circular buffer of samples.
  17. samples []sample
  18. lastValue float32
  19. // Pre-allocated cache for samples.
  20. cache [historySize]sample
  21. // Filtered values and times
  22. values [historySize]float32
  23. times [historySize]float32
  24. }
  25. type sample struct {
  26. t time.Duration
  27. v float32
  28. }
  29. type matrix struct {
  30. rows, cols int
  31. data []float32
  32. }
  33. type Estimate struct {
  34. Velocity float32
  35. Distance float32
  36. }
  37. type coefficients [degree + 1]float32
  38. const (
  39. degree = 2
  40. historySize = 20
  41. maxAge = 100 * time.Millisecond
  42. maxSampleGap = 40 * time.Millisecond
  43. )
  44. // SampleDelta adds a relative sample to the estimation.
  45. func (e *Extrapolation) SampleDelta(t time.Duration, delta float32) {
  46. val := delta + e.lastValue
  47. e.Sample(t, val)
  48. }
  49. // Sample adds an absolute sample to the estimation.
  50. func (e *Extrapolation) Sample(t time.Duration, val float32) {
  51. e.lastValue = val
  52. if e.samples == nil {
  53. e.samples = e.cache[:0]
  54. }
  55. s := sample{
  56. t: t,
  57. v: val,
  58. }
  59. if e.idx == len(e.samples) && e.idx < cap(e.samples) {
  60. e.samples = append(e.samples, s)
  61. } else {
  62. e.samples[e.idx] = s
  63. }
  64. e.idx++
  65. if e.idx == cap(e.samples) {
  66. e.idx = 0
  67. }
  68. }
  69. // Velocity returns an estimate of the implied velocity and
  70. // distance for the points sampled, or zero if the estimation method
  71. // failed.
  72. func (e *Extrapolation) Estimate() Estimate {
  73. if len(e.samples) == 0 {
  74. return Estimate{}
  75. }
  76. values := e.values[:0]
  77. times := e.times[:0]
  78. first := e.get(0)
  79. t := first.t
  80. // Walk backwards collecting samples.
  81. for i := 0; i < len(e.samples); i++ {
  82. p := e.get(-i)
  83. age := first.t - p.t
  84. if age >= maxAge || t-p.t >= maxSampleGap {
  85. // If the samples are too old or
  86. // too much time passed between samples
  87. // assume they're not part of the fling.
  88. break
  89. }
  90. t = p.t
  91. values = append(values, first.v-p.v)
  92. times = append(times, float32((-age).Seconds()))
  93. }
  94. coef, ok := polyFit(times, values)
  95. if !ok {
  96. return Estimate{}
  97. }
  98. dist := values[len(values)-1] - values[0]
  99. return Estimate{
  100. Velocity: coef[1],
  101. Distance: dist,
  102. }
  103. }
  104. func (e *Extrapolation) get(i int) sample {
  105. idx := (e.idx + i - 1 + len(e.samples)) % len(e.samples)
  106. return e.samples[idx]
  107. }
  108. // fit computes the least squares polynomial fit for
  109. // the set of points in X, Y. If the fitting fails
  110. // because of contradicting or insufficient data,
  111. // fit returns false.
  112. func polyFit(X, Y []float32) (coefficients, bool) {
  113. if len(X) != len(Y) {
  114. panic("X and Y lengths differ")
  115. }
  116. if len(X) <= degree {
  117. // Not enough points to fit a curve.
  118. return coefficients{}, false
  119. }
  120. // Use a method similar to Android's VelocityTracker.cpp:
  121. // https://android.googlesource.com/platform/frameworks/base/+/56a2301/libs/androidfw/VelocityTracker.cpp
  122. // where all weights are 1.
  123. // First, expand the X vector to the matrix A in column-major order.
  124. A := newMatrix(degree+1, len(X))
  125. for i, x := range X {
  126. A.set(0, i, 1)
  127. for j := 1; j < A.rows; j++ {
  128. A.set(j, i, A.get(j-1, i)*x)
  129. }
  130. }
  131. Q, Rt, ok := decomposeQR(A)
  132. if !ok {
  133. return coefficients{}, false
  134. }
  135. // Solve R*B = Qt*Y for B, which is then the polynomial coefficients.
  136. // Since R is upper triangular, we can proceed from bottom right to
  137. // upper left.
  138. // https://en.wikipedia.org/wiki/Non-linear_least_squares
  139. var B coefficients
  140. for i := Q.rows - 1; i >= 0; i-- {
  141. B[i] = dot(Q.col(i), Y)
  142. for j := Q.rows - 1; j > i; j-- {
  143. B[i] -= Rt.get(i, j) * B[j]
  144. }
  145. B[i] /= Rt.get(i, i)
  146. }
  147. return B, true
  148. }
// decomposeQR computes and returns Q, Rt where Q*transpose(Rt) = A, if
// possible. R is guaranteed to be upper triangular and only the square
// part of Rt is returned.
//
// A is read in the column-major layout used by polyFit: storage row i of
// A (A.col(i)) is the column vector a_i. On success Q.col(i) holds the
// orthonormal vector q_i, and Rt.get(i, j) holds <q_i, a_j> for j >= i;
// entries with j < i are left at zero. The third result is false, and
// Q and Rt are nil, when some a_i has a norm below 1e-6 after removing
// its components along the earlier q vectors (degenerate data).
func decomposeQR(A *matrix) (*matrix, *matrix, bool) {
	// Gram-Schmidt QR decompose A where Q*R = A.
	// https://en.wikipedia.org/wiki/Gram%E2%80%93Schmidt_process
	Q := newMatrix(A.rows, A.cols)  // Column-major.
	Rt := newMatrix(A.rows, A.rows) // R transposed, row-major.
	for i := 0; i < Q.rows; i++ {
		// Copy A column.
		for j := 0; j < Q.cols; j++ {
			Q.set(i, j, A.get(i, j))
		}
		// Subtract projections. Note that in the projection
		//
		// proju a = <u, a>/<u, u> u
		//
		// the normalized column e replaces u, where <e, e> = 1:
		//
		// proje a = <e, a>/<e, e> e = <e, a> e
		//
		// Each coefficient d is computed against the partially
		// orthogonalized column i (the "modified" Gram-Schmidt
		// variant), not against the original a_i.
		for j := 0; j < i; j++ {
			d := dot(Q.col(j), Q.col(i))
			for k := 0; k < Q.cols; k++ {
				Q.set(i, k, Q.get(i, k)-d*Q.get(j, k))
			}
		}
		// Normalize Q columns.
		n := norm(Q.col(i))
		if n < 0.000001 {
			// Degenerate data, no solution: column i is (nearly) a
			// linear combination of the previous columns.
			return nil, nil, false
		}
		invNorm := 1 / n
		for j := 0; j < Q.cols; j++ {
			Q.set(i, j, Q.get(i, j)*invNorm)
		}
		// Update Rt: entry (i, j) becomes <q_i, a_j> for j >= i.
		for j := i; j < Rt.cols; j++ {
			Rt.set(i, j, dot(Q.col(i), A.col(j)))
		}
	}
	return Q, Rt, true
}
  192. func norm(V []float32) float32 {
  193. var n float32
  194. for _, v := range V {
  195. n += v * v
  196. }
  197. return float32(math.Sqrt(float64(n)))
  198. }
  199. func dot(V1, V2 []float32) float32 {
  200. var d float32
  201. for i, v1 := range V1 {
  202. d += v1 * V2[i]
  203. }
  204. return d
  205. }
  206. func newMatrix(rows, cols int) *matrix {
  207. return &matrix{
  208. rows: rows,
  209. cols: cols,
  210. data: make([]float32, rows*cols),
  211. }
  212. }
  213. func (m *matrix) set(row, col int, v float32) {
  214. if row < 0 || row >= m.rows {
  215. panic("row out of range")
  216. }
  217. if col < 0 || col >= m.cols {
  218. panic("col out of range")
  219. }
  220. m.data[row*m.cols+col] = v
  221. }
  222. func (m *matrix) get(row, col int) float32 {
  223. if row < 0 || row >= m.rows {
  224. panic("row out of range")
  225. }
  226. if col < 0 || col >= m.cols {
  227. panic("col out of range")
  228. }
  229. return m.data[row*m.cols+col]
  230. }
  231. func (m *matrix) col(c int) []float32 {
  232. return m.data[c*m.cols : (c+1)*m.cols]
  233. }
  234. func (m *matrix) approxEqual(m2 *matrix) bool {
  235. if m.rows != m2.rows || m.cols != m2.cols {
  236. return false
  237. }
  238. const epsilon = 0.00001
  239. for row := 0; row < m.rows; row++ {
  240. for col := 0; col < m.cols; col++ {
  241. d := m2.get(row, col) - m.get(row, col)
  242. if d < -epsilon || d > epsilon {
  243. return false
  244. }
  245. }
  246. }
  247. return true
  248. }
  249. func (m *matrix) transpose() *matrix {
  250. t := &matrix{
  251. rows: m.cols,
  252. cols: m.rows,
  253. data: make([]float32, len(m.data)),
  254. }
  255. for i := 0; i < m.rows; i++ {
  256. for j := 0; j < m.cols; j++ {
  257. t.set(j, i, m.get(i, j))
  258. }
  259. }
  260. return t
  261. }
  262. func (m *matrix) mul(m2 *matrix) *matrix {
  263. if m.rows != m2.cols {
  264. panic("mismatched matrices")
  265. }
  266. mm := &matrix{
  267. rows: m.rows,
  268. cols: m2.cols,
  269. data: make([]float32, m.rows*m2.cols),
  270. }
  271. for i := 0; i < mm.rows; i++ {
  272. for j := 0; j < mm.cols; j++ {
  273. var v float32
  274. for k := 0; k < m.rows; k++ {
  275. v += m.get(k, j) * m2.get(i, k)
  276. }
  277. mm.set(i, j, v)
  278. }
  279. }
  280. return mm
  281. }
  282. func (m *matrix) String() string {
  283. var b strings.Builder
  284. for i := 0; i < m.rows; i++ {
  285. for j := 0; j < m.cols; j++ {
  286. v := m.get(i, j)
  287. b.WriteString(strconv.FormatFloat(float64(v), 'g', -1, 32))
  288. b.WriteString(", ")
  289. }
  290. b.WriteString("\n")
  291. }
  292. return b.String()
  293. }
  294. func (c coefficients) approxEqual(c2 coefficients) bool {
  295. const epsilon = 0.00001
  296. for i, v := range c {
  297. d := v - c2[i]
  298. if d < -epsilon || d > epsilon {
  299. return false
  300. }
  301. }
  302. return true
  303. }