diff --git a/cli/cmd/admin.go b/cli/cmd/admin.go new file mode 100644 index 000000000..153736624 --- /dev/null +++ b/cli/cmd/admin.go @@ -0,0 +1,83 @@ +package cmd + +import ( + "context" + "fmt" + "github.com/ente-io/cli/pkg/model" + "github.com/spf13/cobra" + "github.com/spf13/pflag" +) + +var _adminCmd = &cobra.Command{ + Use: "admin", + Short: "Commands for admin actions", + Long: "Commands for admin actions like disable or enabling 2fa, bumping up the storage limit, etc.", +} + +var _userDetailsCmd = &cobra.Command{ + Use: "get-user-id", + Short: "Get user id", + RunE: func(cmd *cobra.Command, args []string) error { + recoverWithLog() + var flags = &model.AdminActionForUser{} + cmd.Flags().VisitAll(func(f *pflag.Flag) { + if f.Name == "admin-user" { + flags.AdminEmail = f.Value.String() + } + if f.Name == "user" { + flags.UserEmail = f.Value.String() + } + }) + return ctrl.GetUserId(context.Background(), *flags) + }, +} + +var _disable2faCmd = &cobra.Command{ + Use: "disable-2fa", + Short: "Disable 2fa for a user", + RunE: func(cmd *cobra.Command, args []string) error { + recoverWithLog() + var flags = &model.AdminActionForUser{} + cmd.Flags().VisitAll(func(f *pflag.Flag) { + if f.Name == "admin-user" { + flags.AdminEmail = f.Value.String() + } + if f.Name == "user" { + flags.UserEmail = f.Value.String() + } + }) + fmt.Println("Not supported yet") + return nil + }, +} + +var _updateFreeUserStorage = &cobra.Command{ + Use: "update-subscription", + Short: "Update subscription for the free user", + RunE: func(cmd *cobra.Command, args []string) error { + recoverWithLog() + var flags = &model.AdminActionForUser{} + cmd.Flags().VisitAll(func(f *pflag.Flag) { + if f.Name == "admin-user" { + flags.AdminEmail = f.Value.String() + } + if f.Name == "user" { + flags.UserEmail = f.Value.String() + } + }) + return ctrl.UpdateFreeStorage(context.Background(), *flags) + }, +} + +func init() { + rootCmd.AddCommand(_adminCmd) + _ = _userDetailsCmd.MarkFlagRequired("admin-user") + _ = _userDetailsCmd.MarkFlagRequired("user") + _userDetailsCmd.Flags().StringP("admin-user", "a", "", "The email of the admin user. (required)") + _userDetailsCmd.Flags().StringP("user", "u", "", "The email of the user to fetch details for. (required)") + _disable2faCmd.Flags().StringP("admin-user", "a", "", "The email of the admin user. (required)") + _disable2faCmd.Flags().StringP("user", "u", "", "The email of the user to disable 2FA for. (required)") + _updateFreeUserStorage.Flags().StringP("admin-user", "a", "", "The email of the admin user. (required)") + _updateFreeUserStorage.Flags().StringP("user", "u", "", "The email of the user to update subscription for. (required)") + _adminCmd.AddCommand(_userDetailsCmd, _disable2faCmd, _updateFreeUserStorage) +} diff --git a/cli/internal/api/admin.go b/cli/internal/api/admin.go new file mode 100644 index 000000000..1c0b2c50a --- /dev/null +++ b/cli/internal/api/admin.go @@ -0,0 +1,57 @@ +package api + +import ( + "context" + "fmt" + "github.com/ente-io/cli/internal/api/models" + "time" +) + +func (c *Client) GetUserIdFromEmail(ctx context.Context, email string) (*models.UserDetails, error) { + var res models.UserDetails + r, err := c.restClient.R(). + SetContext(ctx). + SetResult(&res). + SetQueryParam("email", email). + Get("/admin/user/") + if err != nil { + return nil, err + } + if r.IsError() { + return nil, &ApiError{ + StatusCode: r.StatusCode(), + Message: r.String(), + } + } + return &res, nil +} +func (c *Client) UpdateFreePlanSub(ctx context.Context, userDetails *models.UserDetails, storageInBytes int64, expiryTimeInMicro int64) error { + var res interface{} + if userDetails.Subscription.ProductID != "free" { + return fmt.Errorf("user is not on free plan") + } + payload := map[string]interface{}{ + "userID": userDetails.User.ID, + "expiryTime": expiryTimeInMicro, + "transactionID": fmt.Sprintf("cli-on-%d", time.Now().Unix()), + "productID": "free", + "paymentProvider": "", + "storage": storageInBytes, + } + r, err := c.restClient.R(). + SetContext(ctx). + SetResult(&res). + SetBody(payload). + Put("/admin/user/subscription") + if err != nil { + return err + } + if r.IsError() { + return &ApiError{ + StatusCode: r.StatusCode(), + Message: r.String(), + } + } + return nil + +} diff --git a/cli/internal/api/models/user_details.go b/cli/internal/api/models/user_details.go new file mode 100644 index 000000000..259ff972b --- /dev/null +++ b/cli/internal/api/models/user_details.go @@ -0,0 +1,16 @@ +package models + +type UserDetails struct { + User struct { + ID int64 `json:"id"` + } `json:"user"` + Usage int64 `json:"usage"` + Email string `json:"email"` + + Subscription struct { + ExpiryTime int64 `json:"expiryTime"` + Storage int64 `json:"storage"` + ProductID string `json:"productID"` + PaymentProvider string `json:"paymentProvider"` + } `json:"subscription"` +} diff --git a/cli/internal/promt.go b/cli/internal/promt.go index 2e988ac24..fe2f61bde 100644 --- a/cli/internal/promt.go +++ b/cli/internal/promt.go @@ -5,11 +5,12 @@ import ( "errors" "fmt" "github.com/ente-io/cli/internal/api" + "golang.org/x/term" "log" "os" + "regexp" + "strconv" "strings" - - "golang.org/x/term" ) func GetSensitiveField(label string) (string, error) { @@ -81,6 +82,79 @@ func GetCode(promptText string, length int) (string, error) { } } +// parseStorageSize parses a string representing a storage size (e.g., "500MB", "2GB") into bytes. +func parseStorageSize(input string) (int64, error) { + units := map[string]int64{ + "MB": 1 << 20, + "GB": 1 << 30, + "TB": 1 << 40, + } + re := regexp.MustCompile(`(?i)^(\d+(?:\.\d+)?)(MB|GB|TB)$`) + matches := re.FindStringSubmatch(input) + + if matches == nil { + return 0, errors.New("invalid format") + } + + number, err := strconv.ParseFloat(matches[1], 64) + if err != nil { + return 0, fmt.Errorf("invalid number: %s", matches[1]) + } + + unit := strings.ToUpper(matches[2]) + bytes := int64(number * float64(units[unit])) + + return bytes, nil +} + +func ConfirmAction(promptText string) (bool, error) { + for { + input, err := GetUserInput(promptText) + if err != nil { + return false, err + } + if input == "" { + log.Fatal("No input entered") + return false, errors.New("invalid input. Please enter 'y' or 'n'") + } + if input == "c" { + return false, errors.New("cancelled") + } + if input == "y" { + return true, nil + } + if input == "n" { + return false, nil + } + fmt.Println("Invalid input. Please enter 'y' or 'n'.") + } +} + +// GetStorageSize prompts the user for a storage size and returns the size in bytes. +func GetStorageSize(promptText string) (int64, error) { + for { + input, err := GetUserInput(promptText) + if err != nil { + return 0, err + } + if input == "" { + log.Fatal("No storage size entered") + return 0, errors.New("no storage size entered") + } + if input == "c" { + return 0, errors.New("storage size entry cancelled") + } + + bytes, err := parseStorageSize(input) + if err != nil { + fmt.Println("Invalid storage size format. Please use a valid format like '500MB', '2GB'.") + continue + } + + return bytes, nil + } +} + func GetExportDir() string { for { exportDir, err := GetUserInput("Enter export directory") diff --git a/cli/main.go b/cli/main.go index 957669f9d..f9ad0fa06 100644 --- a/cli/main.go +++ b/cli/main.go @@ -86,7 +86,6 @@ func initConfig(cliConfigPath string) { viper.SetDefault("log.http", false) if err := viper.ReadInConfig(); err != nil { if _, ok := err.(viper.ConfigFileNotFoundError); ok { - log.Printf("Config file not found; using defaults %s", cliConfigPath) } else { // Config file was found but another error was produced } diff --git a/cli/pkg/admin_actions.go b/cli/pkg/admin_actions.go new file mode 100644 index 000000000..74152e5c5 --- /dev/null +++ b/cli/pkg/admin_actions.go @@ -0,0 +1,104 @@ +package pkg + +import ( + "context" + "fmt" + "github.com/ente-io/cli/internal" + "github.com/ente-io/cli/pkg/model" + "github.com/ente-io/cli/utils" + "log" + "strings" + "time" +) + +func (c *ClICtrl) GetUserId(ctx context.Context, params model.AdminActionForUser) error { + accountCtx, err := c.buildAdminContext(ctx, params.AdminEmail) + if err != nil { + return err + } + id, err := c.Client.GetUserIdFromEmail(accountCtx, params.UserEmail) + if err != nil { + return err + } + fmt.Println(id.User.ID) + return nil +} + +func (c *ClICtrl) UpdateFreeStorage(ctx context.Context, params model.AdminActionForUser) error { + accountCtx, err := c.buildAdminContext(ctx, params.AdminEmail) + if err != nil { + return err + } + userDetails, err := c.Client.GetUserIdFromEmail(accountCtx, params.UserEmail) + if err != nil { + return err + } + storageSize, err := internal.GetStorageSize("Enter a storage size (e.g.'5MB', '10GB', '2Tb'): ") + if err != nil { + log.Fatalf("Error: %v", err) + } + dateStr, err := internal.GetUserInput("Enter sub expiry date in YYYY-MM-DD format (e.g.'2040-12-31')") + if err != nil { + log.Fatalf("Error: %v", err) + } + date, err := _parseDateOrDateTime(dateStr) + if err != nil { + return err + } + + fmt.Printf("Updating storage for user %s to %s (old %s) with new expirty %s (old %s) \n", + params.UserEmail, + utils.ByteCountDecimalGIB(storageSize), utils.ByteCountDecimalGIB(userDetails.Subscription.Storage), + date.Format("2006-01-02"), + time.UnixMicro(userDetails.Subscription.ExpiryTime).Format("2006-01-02")) + // press y to confirm + confirmed, _ := internal.ConfirmAction("Are you sure you want to update the storage ('y' or 'n')?") + if !confirmed { + return nil + } else { + err := c.Client.UpdateFreePlanSub(accountCtx, userDetails, storageSize, date.UnixMicro()) + if err != nil { + return err + } else { + fmt.Println("Successfully updated storage and expiry date for user") + } + } + + return nil +} + +func (c *ClICtrl) buildAdminContext(ctx context.Context, adminEmail string) (context.Context, error) { + accounts, err := c.GetAccounts(ctx) + if err != nil { + return nil, err + } + var acc *model.Account + for _, a := range accounts { + if a.Email == adminEmail { + acc = &a + break + } + } + if acc == nil { + return nil, fmt.Errorf("account not found for %s, use `account list` to list accounts", adminEmail) + } + secretInfo, err := c.KeyHolder.LoadSecrets(*acc) + if err != nil { + return nil, err + } + accountCtx := c.buildRequestContext(ctx, *acc) + c.Client.AddToken(acc.AccountKey(), secretInfo.TokenStr()) + return accountCtx, nil +} + +func _parseDateOrDateTime(input string) (time.Time, error) { + var layout string + if strings.Contains(input, " ") { + // If the input contains a space, assume it's a date-time format + layout = "2006-01-02 15:04:05" + } else { + // If there's no space, assume it's just a date + layout = "2006-01-02" + } + return time.Parse(layout, input) +} diff --git a/cli/pkg/model/admin.go b/cli/pkg/model/admin.go new file mode 100644 index 000000000..e4d14d75c --- /dev/null +++ b/cli/pkg/model/admin.go @@ -0,0 +1,6 @@ +package model + +type AdminActionForUser struct { + UserEmail string + AdminEmail string +} diff --git a/cli/utils/convert.go b/cli/utils/convert.go new file mode 100644 index 000000000..5bd7c8026 --- /dev/null +++ b/cli/utils/convert.go @@ -0,0 +1,31 @@ +package utils + +import ( + "fmt" +) + +func ByteCountDecimal(b int64) string { + const unit = 1000 + if b < unit { + return fmt.Sprintf("%d B", b) + } + div, exp := int64(unit), 0 + for n := b / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "kMGTPE"[exp]) +} + +func ByteCountDecimalGIB(b int64) string { + const unit = 1024 + if b < unit { + return fmt.Sprintf("%d B", b) + } + div, exp := int64(unit), 0 + for n := b / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "kMGTPE"[exp]) +} diff --git a/cli/utils/time.go b/cli/utils/time.go index a86a9fc9f..fcb61f842 100644 --- a/cli/utils/time.go +++ b/cli/utils/time.go @@ -1,7 +1,6 @@ package utils import ( - "fmt" "log" "time" ) @@ -10,16 +9,3 @@ func TimeTrack(start time.Time, name string) { elapsed := time.Since(start) log.Printf("%s took %s", name, elapsed) } - -func ByteCountDecimal(b int64) string { - const unit = 1000 - if b < unit { - return fmt.Sprintf("%d B", b) - } - div, exp := int64(unit), 0 - for n := b / unit; n >= unit; n /= unit { - div *= unit - exp++ - } - return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "kMGTPE"[exp]) -}