/* * This Source Code Form is subject to the terms of the Mozilla Public * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at https://mozilla.org/MPL/2.0/. */ package main import ( "context" "database/sql" "encoding/csv" "errors" "flag" "fmt" "io" "os" "os/signal" "path/filepath" "regexp" "strings" "time" "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3" logger "idio.link/go/logger/v3" "golang.org/x/term" ) var ( Version, Build string WithTimestamp bool ) const ( MaxRecords = 1_000_000_000 SQLITE_CONSTRAINT_PRIMARYKEY = 1555 ) func main() { log := logger.NewLogger() flag.BoolVar( &WithTimestamp, "timestamp", false, "Add a column 'timestamp' to imported rows. "+ "All rows receive the same value for 'timestamp', "+ "calculated at the start of the import. "+ "", ) flag.Parse() if len(flag.Args()) != 1 { if Build == "" { Version = "dev" Build = time.Now().UTC().Format(time.RFC3339) } fmt.Printf("%s %s\n", Version, Build) fmt.Printf("usage: %s [--timestamp] \n", os.Args[0]) os.Exit(1) } csvPath, err := filepath.Abs(flag.Args()[0]) if err != nil { log.Fatal("resolve absolute path for '%s': %v", flag.Args()[0], err) } ctx := context.WithValue(context.Background(), "log", log) ctx, cancel := context.WithCancel(ctx) go func() { _, stop := signal.NotifyContext( ctx, os.Interrupt, os.Kill, ) sig := make(chan os.Signal, 1) <-sig stop() log.Info("caught signal; shutting down") cancel() }() // open csv for read f, err := os.Open(csvPath) if err != nil { log.Fatal("open csv: %v", err) } defer func() { if err := f.Close(); err != nil { log.Error("close csv '%s': %v", csvPath, err) } }() r := csv.NewReader(f) r.ReuseRecord = true // open db tmp := strings.TrimSuffix(csvPath, ".csv") dbPath := tmp + ".db" name := scrubName(filepath.Base(tmp)) // no need to rollback or be crash resistant params := "?_synchronous=0&_journal_mode=OFF&_temp_store=2" db, err := sql.Open("sqlite3", dbPath+params) if err != nil { log.Fatal("open db '%s': %v", dbPath, err) } defer func() { if err := db.Close(); err != nil { log.Error("close db '%s': %v", dbPath, err) } }() // Read headers rec, err := r.Read() if err != nil && err != io.EOF { log.Fatal("read csv '%s': %v", csvPath, err) } // Prepare args with timestamp, as necessary. argsLen := len(rec) if WithTimestamp { argsLen++ } args := make([]any, argsLen) if WithTimestamp { args[0] = any(FormatRFC3339Subsec(time.Now().UTC())) } // create table, prepare insert var insert *sql.Stmt headers := make([]string, 0, argsLen) if WithTimestamp { headers = append(headers, "timestamp") } for _, f := range rec { f = scrubName(f) headers = append(headers, f) } _, err = createTable(ctx, db, name, headers) if err != nil { log.Fatal("create table '%s': %v", name, err) } insert, err = genInsert(ctx, db, name, headers) if err != nil { log.Fatal("prepare insert: %v", err) } defer func() { if err := insert.Close(); err != nil { log.Error("close prepared insert: %v", err) } }() // insert records tx, err := db.BeginTx(ctx, nil) if err != nil { log.Fatal("begin transaction: %v", err) } nonUnique := 0 i := 0 for { i++ if i >= MaxRecords { panic(fmt.Sprintf("BUG: max record limit reached: %d", MaxRecords)) } if ctx.Err() != nil { break } // Print count for every thousand-ish records if i&4095 == 0 { printStatus(i) } if i&65535 == 0 { if err := tx.Commit(); err != nil { log.Fatal("commit transaction: %v", err) } tx, err = db.BeginTx(ctx, nil) if err != nil { log.Fatal("begin transaction: %v", err) } } rec, err := r.Read() if err != nil { fmt.Println() if err == io.EOF { fmt.Fprintf(os.Stderr, "read %d records\n", i) fmt.Fprintf(os.Stderr, "skipped %d redundant records\n", nonUnique) } else { log.Error("read csv '%s': %v", csvPath, err) } break } // Rather slow, doing this for every record, but since // we cannot know how many fields will exist ahead of // time (we compile *now*, not after we've read the // headers), we cannot simply enter, say, 27 arguments, // each converted to `any`, by hand. It may be faster to // simply print the SQL statements and pipe them to // sqlite. Should check this. recArgs := args[:] if WithTimestamp { // Skip over timestamp recArgs = args[1:] } for j := 0; j < len(recArgs); j++ { recArgs[j] = any(rec[j]) } // For explanation of extended codes, see https://www.sqlite.org/rescode.html _, err = insert.ExecContext(ctx, args...) if err != nil { var sqliteErr sqlite3.Error if errors.As(err, &sqliteErr) && sqliteErr.ExtendedCode == SQLITE_CONSTRAINT_PRIMARYKEY { nonUnique++ log.Debug("insert record '%#v': %v", rec, err) } else { log.Fatal("insert record '%#v': %v", rec, err) } } } if err := tx.Commit(); err != nil { log.Fatal("commit transaction: %v", err) } } func printStatus(n int) { if term.IsTerminal(int(os.Stdout.Fd())) { fmt.Printf("\x1b[%dG", 1) fmt.Printf("%d", n) } } func genInsert( ctx context.Context, db *sql.DB, name string, headers []string, ) (*sql.Stmt, error) { log := ctx.Value("log").(*logger.Logger) var b strings.Builder b.WriteString(fmt.Sprintf("INSERT INTO %s (", name)) for i, h := range headers { if i != 0 { b.WriteString(", ") } b.WriteString(h) } b.WriteString(") VALUES (") for i := 0; i < len(headers); i++ { if i != 0 { b.WriteString(", ") } b.WriteString(fmt.Sprintf("$%d", i+1)) } b.WriteString(");") log.Debug("debug: prepare insert: %s\n", b.String()) return db.PrepareContext(ctx, b.String()) } func createTable( ctx context.Context, db *sql.DB, name string, headers []string, ) (sql.Result, error) { var b strings.Builder b.WriteString( fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (", name), ) for i, h := range headers { if i != 0 { b.WriteString(", ") } b.WriteString(fmt.Sprintf("%s text", h)) } b.WriteString(");") fmt.Fprintf(os.Stderr, "%s\n", b.String()) return db.ExecContext(ctx, b.String()) } var badRunes = regexp.MustCompile(`[^a-zA-Z0-9_\-\.]+`) var badFirst = regexp.MustCompile(`^[^a-zA-Z]+`) var parens = regexp.MustCompile(`[\(\)]`) var spaces = regexp.MustCompile(`\s+`) var hyphens = regexp.MustCompile(`\-+`) var unders = regexp.MustCompile(`_+`) func scrubName(s string) string { s = strings.ToLower(s) s = badFirst.ReplaceAllLiteralString(s, "") s = parens.ReplaceAllLiteralString(s, "") s = strings.TrimSpace(s) s = spaces.ReplaceAllLiteralString(s, "_") s = badRunes.ReplaceAllLiteralString(s, "_") s = hyphens.ReplaceAllLiteralString(s, "_") s = unders.ReplaceAllLiteralString(s, "_") s = strings.TrimSuffix(s, "_") return fmt.Sprintf("'%s'", s) } func FormatRFC3339Subsec(t time.Time) string { // RFC3339Nano = "2006-01-02T15:04:05.999999999Z07:00" rfc3339 := t.Format(time.RFC3339Nano) datetime := rfc3339[:19] nano := rfc3339[20:23] // Just ms return fmt.Sprintf( "%s.%sZ", datetime, nano, ) }