package main import ( "context" "encoding/json" "flag" "fmt" "io" "io/fs" "os" "os/signal" "path/filepath" "runtime" "sync" "sync/atomic" "syscall" "github.com/sourcegraph/conc/pool" ) var ( sourceDir = flag.String("source", ".", "source data directory") tempDir = flag.String("temp", "/tmp", "temporary storage directory") dbPath = flag.String("db", "datashake.json", "database file path") minimumSize = flag.Int64("min-size", 1024*1024, "minimum size in bytes") concurrency = flag.Int("concurrency", 1, "concurrent processing limit") ) func main() { flag.Parse() ctx, stop := signal.NotifyContext( context.Background(), syscall.SIGINT, syscall.SIGTERM, ) defer stop() concurrency := *concurrency if concurrency < 1 { concurrency = runtime.GOMAXPROCS(0) } tasks = pool.New().WithMaxGoroutines(concurrency) if err := loadDb(); err != nil { fmt.Println("error", err) os.Exit(1) } go run() errors := errors actions := actions for { select { case err, ok := <-errors: if !ok { errors = nil break } db.Alert(err) case action, ok := <-actions: if !ok { actions = nil break } db.Record(action) case <-ctx.Done(): running.Store(false) } if errors == nil && actions == nil { break } } if err := saveDb(); err != nil { fmt.Println("error", err) os.Exit(1) } } var ( running atomic.Bool tasks *pool.Pool errors = make(chan Error) actions = make(chan Action) db = DB{ Seen: make(map[string]struct{}), } dbLock sync.Mutex ) // run drives the directory traversal. func run() { defer func() { close(errors) close(actions) }() running.Store(true) if err := filepath.WalkDir(*sourceDir, process); err != nil { fmt.Println("error:", err) } tasks.Wait() } // process is a visitor for `filepath.WalkDir` that performs the rebalancing // algorithm against regular files. // // This function normally never returns an error, since that would stop the // directory walk. Instead, any errors are sent to the `errors` channel. func process(path string, d fs.DirEntry, err error) (typicallyNil error) { if !running.Load() { return fs.SkipAll } if err != nil || d.IsDir() || !d.Type().IsRegular() { return } tasks.Go(func() { if running.Load() { work(path, d) } }) return } // work rebalances a single file. func work(path string, d fs.DirEntry) { var err error var srcFilePath = path var tempFilePath string reportErr := func(reported error) { e := Error{ Message: reported.Error(), FilePath: srcFilePath, TempPath: tempFilePath, } errors <- e } defer func() { if err != nil { reportErr(err) } }() srcFileName := d.Name() srcFilePath, err = filepath.Abs(path) if err != nil { return } if db.Knows(srcFilePath) { return } srcStat, err := os.Stat(srcFilePath) if err != nil { return } if srcStat.Size() < *minimumSize { return } tempDirPath, err := os.MkdirTemp(*tempDir, "*") if err != nil { return } tempFilePath = filepath.Join(tempDirPath, srcFileName) safeToRemoveTemp := true defer func() { if !safeToRemoveTemp { reportErr(missingFile) return } if err := os.RemoveAll(tempDirPath); err != nil { reportErr(err) } }() err = copy(srcFilePath, tempFilePath) if err != nil { return } safeToRemoveTemp = false err = os.Remove(srcFilePath) if err != nil { return } err = copy(tempFilePath, srcFilePath) if err != nil { return } safeToRemoveTemp = true db.Remember(srcFilePath) } var missingFile = fmt.Errorf("file may be missing in temp directory") // copy opens the file from the source path, then creates a copy of it at the // destination path. The mode, uid and gid bits from the source file are // replicated in the copy. func copy(srcPath, dstPath string) error { fmt.Println("copying", srcPath, "to", dstPath) srcFile, err := os.Open(srcPath) if err != nil { return err } defer func() { _ = srcFile.Close() }() dstFile, err := os.Create(dstPath) if err != nil { return err } defer func() { _ = dstFile.Close() }() srcStat, err := os.Stat(srcPath) if err != nil { return err } err = os.Chmod(dstPath, srcStat.Mode()) if err != nil { return err } if sysStat, ok := srcStat.Sys().(*syscall.Stat_t); ok { uid := int(sysStat.Uid) gid := int(sysStat.Gid) err = os.Chown(dstPath, uid, gid) if err != nil { return err } } _, err = io.Copy(dstFile, srcFile) if err != nil { return err } actions <- Action{ Source: srcPath, Destination: dstPath, } return nil } // DB stores information collected by the program. // // The database is loaded from a JSON file when the program starts and saved // back to that JSON file as the program finishes. type DB struct { // Seen is a set of all files which have been successfully re-balanced. Seen map[string]struct{} // Log stores every successful copy operation. Log []Action // Errors stores details on any error that occurs. Errors []Error } // Action details a file copy operation. type Action struct { Source string Destination string } // Error describes a problem from re-balancing a file. type Error struct { // Message contains a string for the underlying error's message. Message string // FilePath is the path of the file. FilePath string // TempPath is the temporary directory. // // This may be blank, depending on when the error occurred. TempPath string } func (db *DB) Knows(path string) bool { dbLock.Lock() defer dbLock.Unlock() _, ok := db.Seen[path] return ok } func (db *DB) Remember(path string) { dbLock.Lock() defer dbLock.Unlock() db.Seen[path] = struct{}{} } func (db *DB) Record(a Action) { dbLock.Lock() defer dbLock.Unlock() db.Log = append(db.Log, a) } func (db *DB) Alert(e Error) { dbLock.Lock() defer dbLock.Unlock() db.Errors = append(db.Errors, e) } func loadDb() error { if *dbPath == "" { return nil } dbFile, err := os.Open(*dbPath) if err != nil { return nil } defer func() { _ = dbFile.Close() }() d := json.NewDecoder(dbFile) err = d.Decode(&db) return err } func saveDb() error { if *dbPath == "" { return nil } dbFile, err := os.Create(*dbPath) if err != nil { return err } defer func() { _ = dbFile.Close() }() e := json.NewEncoder(dbFile) err = e.Encode(&db) return err }