Process files concurrently on a bounded worker pool.

Review fixes folded into this revision of the patch:
  * main: the cancellation channel is set to nil once it has fired, so the
    select loop does not spin on the closed ctx.Done() channel while
    in-flight work drains.
  * process: errors handed in by filepath.WalkDir are sent to the `errors`
    channel again instead of being dropped.
  * go.mod: conc is imported directly by datashake.go, so the requirement
    is not marked `// indirect`.

The `index` lines are kept as originally recorded and may not match the
blobs produced by this revision.

diff --git a/.vscode/settings.json b/.vscode/settings.json
index 907e121..58b46dc 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -1,5 +1,8 @@
 {
     "cSpell.words": [
-        "datashake"
+        "datashake",
+        "rebalances",
+        "rebalancing",
+        "sourcegraph"
     ]
 }
diff --git a/datashake.go b/datashake.go
index 95bc91c..f6f67f1 100644
--- a/datashake.go
+++ b/datashake.go
@@ -1,6 +1,7 @@
 package main
 
 import (
+	"context"
 	"encoding/json"
 	"flag"
 	"fmt"
@@ -9,13 +10,37 @@ import (
 	"os"
 	"os/signal"
 	"path/filepath"
+	"runtime"
+	"sync"
 	"sync/atomic"
 	"syscall"
+
+	"github.com/sourcegraph/conc/pool"
+)
+
+var (
+	sourceDir   = flag.String("source", ".", "source data directory")
+	tempDir     = flag.String("temp", "/tmp", "temporary storage directory")
+	dbPath      = flag.String("db", "datashake.json", "database file path")
+	minimumSize = flag.Int64("min-size", 1024*1024, "minimum size in bytes")
+	concurrency = flag.Int("concurrency", 1, "concurrent processing limit")
 )
 
 func main() {
+	ctx, stop := signal.NotifyContext(
+		context.Background(),
+		syscall.SIGINT, syscall.SIGTERM,
+	)
+	defer stop()
+
 	flag.Parse()
 
+	concurrency := *concurrency
+	if concurrency < 1 {
+		concurrency = runtime.GOMAXPROCS(0)
+	}
+	tasks = pool.New().WithMaxGoroutines(concurrency)
+
 	if err := loadDb(); err != nil {
 		fmt.Println("error", err)
 		os.Exit(1)
@@ -26,12 +51,14 @@ func main() {
 		if err := filepath.WalkDir(*sourceDir, process); err != nil {
 			errors <- err
 		}
+		pending.Wait()
 		close(errors)
 	}()
 
-	signals := make(chan os.Signal, 1)
-	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
+	// A closed Done channel is always ready to receive, so it is set to nil
+	// after the first cancellation to keep the select below from spinning.
+	done := ctx.Done()
 
 Loop:
 	for {
 		select {
@@ -41,7 +68,8 @@ Loop:
 			} else {
 				break Loop
 			}
-		case <-signals:
+		case <-done:
 			running.Store(false)
+			done = nil
 		}
 	}
 
@@ -52,69 +80,55 @@ Loop:
 	}
 }
 
-var sourceDir = flag.String("source", "", "source data directory")
-var tempDir = flag.String("temp", "", "temporary storage directory")
-var dbPath = flag.String("db", "datashake.json", "database file path")
-var minimumSize = flag.Int64("min-size", 1024*1024, "minimum size in bytes")
+var (
+	running atomic.Bool
 
-var errors = make(chan error)
-var db = DB{
-	Processed: make(map[string]struct{}),
-}
-var running atomic.Bool
+	tasks   *pool.Pool
+	pending sync.WaitGroup
+	errors  = make(chan error)
 
-func loadDb() error {
-	if *dbPath == "" {
-		return nil
+	db = DB{
+		Processed: make(map[string]struct{}),
 	}
-	dbFile, err := os.Open(*dbPath)
-	if err != nil {
-		return nil
-	}
-	defer func() {
-		_ = dbFile.Close()
-	}()
-	d := json.NewDecoder(dbFile)
-	err = d.Decode(&db)
-	return err
-}
+	dbLock sync.Mutex
+)
 
-func saveDb() error {
-	if *dbPath == "" {
-		return nil
-	}
-	dbFile, err := os.Create(*dbPath)
-	if err != nil {
-		return err
-	}
-	defer func() {
-		_ = dbFile.Close()
-	}()
-	e := json.NewEncoder(dbFile)
-	err = e.Encode(&db)
-	return err
-}
-
-// process is a visitor for `filepath.WalkDir` that implements the rebalancing
-// algorithm.
+// process is a visitor for `filepath.WalkDir` that performs the rebalancing
+// algorithm against regular files.
 //
-// This function never returns an error, since that would stop the directory
+// This function normally never returns an error, since that would stop the
 // directory walk. Instead, any errors are sent to the `errors` channel.
-func process(path string, d fs.DirEntry, err error) (alwaysNil error) {
+func process(path string, d fs.DirEntry, err error) (typicallyNil error) {
 	if !running.Load() {
 		return fs.SkipAll
 	}
+	if err != nil {
+		// Surface walk errors rather than dropping them silently.
+		errors <- err
+		return
+	}
+	if d.IsDir() || !d.Type().IsRegular() {
+		return
+	}
+	pending.Add(1)
+	tasks.Go(func() {
+		defer pending.Done()
+		if running.Load() {
+			work(path, d)
+		}
+	})
+	return
+}
 
+// work rebalances a single file.
+func work(path string, d fs.DirEntry) {
+	var err error
 	defer func() {
 		if err != nil {
 			errors <- err
 		}
 	}()
 
-	if err != nil || d.IsDir() {
-		return
-	}
-
 	srcFileName := d.Name()
 	srcFilePath, err := filepath.Abs(path)
 	if err != nil {
@@ -168,8 +182,6 @@ func process(path string, d fs.DirEntry, err error) (alwaysNil error) {
 	}
 	safeToRemoveTemp = true
 	db.Remember(srcFilePath)
-
-	return
 }
 
 // copy opens the file from the source path, then creates a copy of it at the
@@ -214,15 +226,57 @@ func copy(srcPath, dstPath string) error {
 	return err
 }
 
+// DB holds a set of files which have been rebalanced.
+//
+// These files are skipped on future runs of the program.
+//
+// The database is loaded from a JSON file when the program starts and saved
+// back to that JSON file as the program finishes.
 type DB struct {
 	Processed map[string]struct{}
 }
 
 func (db *DB) Contains(path string) bool {
+	dbLock.Lock()
+	defer dbLock.Unlock()
 	_, ok := db.Processed[path]
 	return ok
 }
 
 func (db *DB) Remember(path string) {
+	dbLock.Lock()
+	defer dbLock.Unlock()
 	db.Processed[path] = struct{}{}
 }
+
+func loadDb() error {
+	if *dbPath == "" {
+		return nil
+	}
+	dbFile, err := os.Open(*dbPath)
+	if err != nil {
+		return nil
+	}
+	defer func() {
+		_ = dbFile.Close()
+	}()
+	d := json.NewDecoder(dbFile)
+	err = d.Decode(&db)
+	return err
+}
+
+func saveDb() error {
+	if *dbPath == "" {
+		return nil
+	}
+	dbFile, err := os.Create(*dbPath)
+	if err != nil {
+		return err
+	}
+	defer func() {
+		_ = dbFile.Close()
+	}()
+	e := json.NewEncoder(dbFile)
+	err = e.Encode(&db)
+	return err
+}
diff --git a/go.mod b/go.mod
index 333f7ed..daedd34 100644
--- a/go.mod
+++ b/go.mod
@@ -1,3 +1,5 @@
 module gogs.humancabbage.net/datashake
 
 go 1.21
+
+require github.com/sourcegraph/conc v0.3.0
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..29e2b58
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,2 @@
+github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
+github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=