package main import ( "context" "encoding/json" "flag" "fmt" "io" "io/fs" "os" "os/signal" "path/filepath" "runtime" "sync" "sync/atomic" "syscall" "github.com/sourcegraph/conc/pool" ) var ( sourceDir = flag.String("source", ".", "source data directory") tempDir = flag.String("temp", "/tmp", "temporary storage directory") dbPath = flag.String("db", "datashake.json", "database file path") minimumSize = flag.Int64("min-size", 1024*1024, "minimum size in bytes") concurrency = flag.Int("concurrency", 1, "concurrent processing limit") ) func main() { ctx, stop := signal.NotifyContext( context.Background(), syscall.SIGINT, syscall.SIGTERM, ) defer stop() flag.Parse() concurrency := *concurrency if concurrency < 1 { concurrency = runtime.GOMAXPROCS(0) } tasks = pool.New().WithMaxGoroutines(concurrency) if err := loadDb(); err != nil { fmt.Println("error", err) os.Exit(1) } running.Store(true) go func() { if err := filepath.WalkDir(*sourceDir, process); err != nil { errors <- err } pending.Wait() close(errors) }() Loop: for { select { case err, ok := <-errors: if ok { fmt.Println("error:", err) } else { break Loop } case <-ctx.Done(): running.Store(false) } } if err := saveDb(); err != nil { fmt.Println("error", err) os.Exit(1) } } var ( running atomic.Bool tasks *pool.Pool pending sync.WaitGroup errors = make(chan error) db = DB{ Processed: make(map[string]struct{}), } dbLock sync.Mutex ) // process is a visitor for `filepath.WalkDir` that performs the rebalancing // algorithm against regular files. // // This function normally never returns an error, since that would stop the // directory walk. Instead, any errors are sent to the `errors` channel. func process(path string, d fs.DirEntry, err error) (typicallyNil error) { if !running.Load() { return fs.SkipAll } if err != nil || d.IsDir() || !d.Type().IsRegular() { return } pending.Add(1) tasks.Go(func() { defer pending.Done() if running.Load() { work(path, d) } }) return } // work rebalances a single file. func work(path string, d fs.DirEntry) { var err error defer func() { if err != nil { errors <- err } }() srcFileName := d.Name() srcFilePath, err := filepath.Abs(path) if err != nil { return } if db.Contains(srcFilePath) { return } srcStat, err := os.Stat(srcFilePath) if err != nil { return } if srcStat.Size() < *minimumSize { return } tempDirPath, err := os.MkdirTemp(*tempDir, "*") if err != nil { return } tempFilePath := filepath.Join(tempDirPath, srcFileName) safeToRemoveTemp := true defer func() { if !safeToRemoveTemp { err := fmt.Errorf( "%s may be lost in %s", srcFilePath, tempDirPath, ) errors <- err return } if err := os.RemoveAll(tempDirPath); err != nil { errors <- err } }() err = copy(srcFilePath, tempFilePath) if err != nil { return } safeToRemoveTemp = false err = os.Remove(srcFilePath) if err != nil { return } err = copy(tempFilePath, srcFilePath) if err != nil { return } safeToRemoveTemp = true db.Remember(srcFilePath) } // copy opens the file from the source path, then creates a copy of it at the // destination path. The mode, uid and gid bits from the source file are // replicated in the copy. func copy(srcPath, dstPath string) error { fmt.Println("copying", srcPath, "to", dstPath) srcFile, err := os.Open(srcPath) if err != nil { return err } defer func() { _ = srcFile.Close() }() dstFile, err := os.Create(dstPath) if err != nil { return err } defer func() { _ = dstFile.Close() }() srcStat, err := os.Stat(srcPath) if err != nil { return err } err = os.Chmod(dstPath, srcStat.Mode()) if err != nil { return err } if sysStat, ok := srcStat.Sys().(*syscall.Stat_t); ok { uid := int(sysStat.Uid) gid := int(sysStat.Gid) err = os.Chown(dstPath, uid, gid) if err != nil { return err } } _, err = io.Copy(dstFile, srcFile) return err } // DB holds a set of files which have been rebalanced. // // These files are skipped on future runs of the program. // // The database is loaded from a JSON file when the program starts and saved // back to that JSON file as the program finishes. type DB struct { Processed map[string]struct{} } func (db *DB) Contains(path string) bool { dbLock.Lock() defer dbLock.Unlock() _, ok := db.Processed[path] return ok } func (db *DB) Remember(path string) { dbLock.Lock() defer dbLock.Unlock() db.Processed[path] = struct{}{} } func loadDb() error { if *dbPath == "" { return nil } dbFile, err := os.Open(*dbPath) if err != nil { return nil } defer func() { _ = dbFile.Close() }() d := json.NewDecoder(dbFile) err = d.Decode(&db) return err } func saveDb() error { if *dbPath == "" { return nil } dbFile, err := os.Create(*dbPath) if err != nil { return err } defer func() { _ = dbFile.Close() }() e := json.NewEncoder(dbFile) err = e.Encode(&db) return err }