Add some requesters
This commit is contained in:
130
internal/service/password.go
Normal file
130
internal/service/password.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
)
|
||||
|
||||
func getDefaultArgon2Params() *Argon2Params {
|
||||
config, err := GetConfig()
|
||||
if err != nil {
|
||||
log.Fatalf("[Error] Failed to load config: %v", err)
|
||||
}
|
||||
params := &Argon2Params{
|
||||
Memory: config.Password.Memory * 1024,
|
||||
Iterations: config.Password.Iterations,
|
||||
Parallelism: config.Password.Parallelism,
|
||||
SaltLength: config.Password.SaltLength,
|
||||
KeyLength: config.Password.KeyLength,
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
// HashBytes returns the hex string of the given hasher after writing data.
|
||||
func HashBytes(h hash.Hash, data []byte) string {
|
||||
h.Write(data)
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
|
||||
// HashPassword uses Argon2id to securely hash the original password
|
||||
func HashPassword(password string) (string, error) {
|
||||
params := getDefaultArgon2Params()
|
||||
|
||||
// Generate random salt
|
||||
salt := make([]byte, params.SaltLength)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return "", fmt.Errorf("failed to generate salt: %w", err)
|
||||
}
|
||||
|
||||
// Use Argon2id to compute hash
|
||||
hash := argon2.IDKey(
|
||||
[]byte(password),
|
||||
salt,
|
||||
params.Iterations,
|
||||
params.Memory,
|
||||
params.Parallelism,
|
||||
params.KeyLength,
|
||||
)
|
||||
|
||||
// Encode as string to store (compatible with PHC format)
|
||||
b64Salt := base64.RawStdEncoding.EncodeToString(salt)
|
||||
b64Hash := base64.RawStdEncoding.EncodeToString(hash)
|
||||
|
||||
// Format: $argon2id$v=19$m=65536,t=3,p=2$c2FsdA$hash...
|
||||
return fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
||||
argon2.Version,
|
||||
params.Memory,
|
||||
params.Iterations,
|
||||
params.Parallelism,
|
||||
b64Salt,
|
||||
b64Hash,
|
||||
), nil
|
||||
}
|
||||
|
||||
// CheckPassword Verify if the original password matches the Argon2id hash
|
||||
func CheckPassword(password, hash string) bool {
|
||||
ok, err := verifyPassword(password, hash)
|
||||
return err == nil && ok
|
||||
}
|
||||
|
||||
// verifyPassword It is an internal validation function that returns Boolean values and errors
|
||||
func verifyPassword(password, hash string) (bool, error) {
|
||||
parts := parseHash(hash)
|
||||
if parts == nil {
|
||||
return false, errors.New("invalid hash format")
|
||||
}
|
||||
|
||||
salt, err := base64.RawStdEncoding.DecodeString(parts.SaltBase64)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
expectedHash, err := base64.RawStdEncoding.DecodeString(parts.HashBase64)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Recalculate Hash
|
||||
computed := argon2.IDKey(
|
||||
[]byte(password),
|
||||
salt,
|
||||
parts.Iterations,
|
||||
parts.Memory,
|
||||
parts.Parallelism,
|
||||
uint32(len(expectedHash)),
|
||||
)
|
||||
|
||||
// Constant time comparison to prevent timing attacks
|
||||
return subtle.ConstantTimeCompare(computed, expectedHash) == 1, nil
|
||||
}
|
||||
|
||||
// parseHash Parse the $argon2id$... format string
|
||||
func parseHash(hash string) *HashParts {
|
||||
vals := strings.Split(hash, "$")
|
||||
if len(vals) != 6 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var m, t uint32
|
||||
var p uint8
|
||||
_, err := fmt.Sscanf(vals[3], "m=%d,t=%d,p=%d", &m, &t, &p)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &HashParts{
|
||||
Memory: m,
|
||||
Iterations: t,
|
||||
Parallelism: p,
|
||||
SaltBase64: vals[4],
|
||||
HashBase64: vals[5],
|
||||
}
|
||||
}
|
||||
7
internal/service/rand.go
Normal file
7
internal/service/rand.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package service
|
||||
|
||||
import "crypto/rand"
|
||||
|
||||
// RandReader is a cryptographically secure random number generator.
|
||||
// It is an alias to crypto/rand.Reader for convenience and testability.
|
||||
var RandReader = rand.Reader
|
||||
@@ -5,15 +5,16 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
)
|
||||
|
||||
// ConnectDatabase connects to MySQL and returns a *sql.DB handle.
|
||||
// The caller is responsible for calling db.Close() when done.
|
||||
func ConnectDatabase(host string, port int, user string, pass string, dbName string) (*sql.DB, error) {
|
||||
func ConnectDatabase(driver, host string, port int, user string, pass string, dbName string) (*sql.DB, error) {
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", user, pass, host, port, dbName)
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
db, err := sql.Open(driver, dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("[Error] Failed to connect to MySQL: %v", err)
|
||||
}
|
||||
@@ -23,6 +24,11 @@ func ConnectDatabase(host string, port int, user string, pass string, dbName str
|
||||
return nil, fmt.Errorf("[Error] Failed to ping MySQL: %v", err)
|
||||
}
|
||||
|
||||
// Set connection pool parameters
|
||||
db.SetMaxOpenConns(25)
|
||||
db.SetMaxIdleConns(5)
|
||||
db.SetConnMaxLifetime(15 * time.Minute)
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
@@ -35,21 +41,11 @@ func CloseDatabase(db *sql.DB) {
|
||||
}
|
||||
}
|
||||
|
||||
// DatabaseExists checks if a database exists
|
||||
func DatabaseExists(db *sql.DB, dbName string) bool {
|
||||
var count int
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM information_schema.schemata WHERE schema_name = ?", dbName).Scan(&count)
|
||||
if err != nil {
|
||||
log.Printf("[Warning] Failed to check database existence: %v", err)
|
||||
return false
|
||||
}
|
||||
return count > 0
|
||||
}
|
||||
|
||||
// TableExists checks if a table exists in the current database
|
||||
func TableExists(db *sql.DB, tableName string) bool {
|
||||
func TableExists(db *sql.DB, dbPrefix, tableName string) bool {
|
||||
var count int
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = ?", tableName).Scan(&count)
|
||||
queryTableName := dbPrefix + tableName
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = ?", queryTableName).Scan(&count)
|
||||
if err != nil {
|
||||
log.Printf("[Warning] Failed to check table existence: %v", err)
|
||||
return false
|
||||
@@ -57,38 +53,13 @@ func TableExists(db *sql.DB, tableName string) bool {
|
||||
return count > 0
|
||||
}
|
||||
|
||||
// CreateDatabase creates a new database
|
||||
func CreateDatabase(db *sql.DB, dbName string) error {
|
||||
if !isValidName(dbName) {
|
||||
return fmt.Errorf("invalid database name: %s", dbName)
|
||||
}
|
||||
_, err := db.Exec(fmt.Sprintf("CREATE DATABASE `%s`", dbName))
|
||||
if err != nil {
|
||||
log.Printf("[Warning] Failed to create database: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DropDatabase drops a database
|
||||
func DropDatabase(db *sql.DB, dbName string) error {
|
||||
if !isValidName(dbName) {
|
||||
return fmt.Errorf("invalid database name: %s", dbName)
|
||||
}
|
||||
_, err := db.Exec(fmt.Sprintf("DROP DATABASE `%s`", dbName))
|
||||
if err != nil {
|
||||
log.Printf("[Warning] Failed to drop database: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateTable creates a new table in the current database
|
||||
func CreateTable(db *sql.DB, tableName string, tableDef string) error {
|
||||
func CreateTable(db *sql.DB, dbPrefix, tableName string, tableDef string) error {
|
||||
if !isValidName(tableName) {
|
||||
return fmt.Errorf("invalid table name: %s", tableName)
|
||||
}
|
||||
_, err := db.Exec(fmt.Sprintf("CREATE TABLE `%s` (%s)", tableName, tableDef))
|
||||
queryTableName := dbPrefix + tableName
|
||||
_, err := db.Exec(fmt.Sprintf("CREATE TABLE `%s` (%s)", queryTableName, tableDef))
|
||||
if err != nil {
|
||||
log.Printf("[Warning] Failed to create table: %v", err)
|
||||
return err
|
||||
@@ -97,11 +68,12 @@ func CreateTable(db *sql.DB, tableName string, tableDef string) error {
|
||||
}
|
||||
|
||||
// DropTable drops a table from the current database
|
||||
func DropTable(db *sql.DB, tableName string) error {
|
||||
func DropTable(db *sql.DB, dbPrefix, tableName string) error {
|
||||
if !isValidName(tableName) {
|
||||
return fmt.Errorf("invalid table name: %s", tableName)
|
||||
}
|
||||
_, err := db.Exec(fmt.Sprintf("DROP TABLE `%s`", tableName))
|
||||
queryTableName := dbPrefix + tableName
|
||||
_, err := db.Exec(fmt.Sprintf("DROP TABLE `%s`", queryTableName))
|
||||
if err != nil {
|
||||
log.Printf("[Warning] Failed to drop table: %v", err)
|
||||
return err
|
||||
@@ -110,7 +82,7 @@ func DropTable(db *sql.DB, tableName string) error {
|
||||
}
|
||||
|
||||
// InsertRow inserts a new row into a table
|
||||
func InsertRow(db *sql.DB, tableName string, rowData map[string]interface{}) error {
|
||||
func InsertRow(db *sql.DB, dbPrefix, tableName string, rowData map[string]interface{}) error {
|
||||
if !isValidName(tableName) {
|
||||
return fmt.Errorf("invalid table name: %s", tableName)
|
||||
}
|
||||
@@ -125,8 +97,9 @@ func InsertRow(db *sql.DB, tableName string, rowData map[string]interface{}) err
|
||||
params = append(params, val)
|
||||
}
|
||||
|
||||
queryTableName := dbPrefix + tableName
|
||||
placeholders := strings.Repeat("?, ", len(params)-1) + "?"
|
||||
query := fmt.Sprintf("INSERT INTO `%s` (%s) VALUES (%s)", tableName, strings.Join(cols, ", "), placeholders)
|
||||
query := fmt.Sprintf("INSERT INTO `%s` (%s) VALUES (%s)", queryTableName, strings.Join(cols, ", "), placeholders)
|
||||
|
||||
_, err := db.Exec(query, params...)
|
||||
if err != nil {
|
||||
@@ -137,7 +110,7 @@ func InsertRow(db *sql.DB, tableName string, rowData map[string]interface{}) err
|
||||
}
|
||||
|
||||
// UpdateRow updates a row in a table
|
||||
func UpdateRow(db *sql.DB, tableName string, rowData map[string]interface{}, where string, whereArgs ...interface{}) error {
|
||||
func UpdateRow(db *sql.DB, dbPrefix, tableName string, rowData map[string]interface{}, where string, whereArgs ...interface{}) error {
|
||||
if !isValidName(tableName) {
|
||||
return fmt.Errorf("invalid table name: %s", tableName)
|
||||
}
|
||||
@@ -146,7 +119,8 @@ func UpdateRow(db *sql.DB, tableName string, rowData map[string]interface{}, whe
|
||||
log.Printf("[Warning] Failed to build update query: %v", err)
|
||||
return err
|
||||
}
|
||||
query := fmt.Sprintf("UPDATE `%s` SET %s WHERE %s", tableName, setCols, where)
|
||||
queryTableName := dbPrefix + tableName
|
||||
query := fmt.Sprintf("UPDATE `%s` SET %s WHERE %s", queryTableName, setCols, where)
|
||||
_, err = db.Exec(query, append(setVals, whereArgs...)...)
|
||||
if err != nil {
|
||||
log.Printf("[Warning] Failed to update row: %v", err)
|
||||
@@ -176,11 +150,12 @@ func buildUpdateQuery(rowData map[string]interface{}) (string, []interface{}, er
|
||||
}
|
||||
|
||||
// DeleteRow deletes a row from a table
|
||||
func DeleteRow(db *sql.DB, tableName string, where string, whereArgs ...interface{}) error {
|
||||
func DeleteRow(db *sql.DB, dbPrefix, tableName string, where string, whereArgs ...interface{}) error {
|
||||
if !isValidName(tableName) {
|
||||
return fmt.Errorf("invalid table name: %s", tableName)
|
||||
}
|
||||
query := fmt.Sprintf("DELETE FROM `%s` WHERE %s", tableName, where)
|
||||
queryTableName := dbPrefix + tableName
|
||||
query := fmt.Sprintf("DELETE FROM `%s` WHERE %s", queryTableName, where)
|
||||
_, err := db.Exec(query, whereArgs...)
|
||||
if err != nil {
|
||||
log.Printf("[Warning] Failed to delete row: %v", err)
|
||||
@@ -198,7 +173,7 @@ func isValidName(s string) bool {
|
||||
|
||||
// QueryRows queries rows from a table with optional WHERE clause.
|
||||
// Example: rows, err := QueryRows(db, "users", "*", "age > ? AND status = ?", 18, "active")
|
||||
func QueryRows(db *sql.DB, tableName string, columns string, where string, whereArgs ...interface{}) (*sql.Rows, error) {
|
||||
func QueryRows(db *sql.DB, dbPrefix, tableName string, columns string, where string, whereArgs ...interface{}) (*sql.Rows, error) {
|
||||
if !isValidName(tableName) {
|
||||
return nil, fmt.Errorf("invalid table name: %s", tableName)
|
||||
}
|
||||
@@ -208,11 +183,12 @@ func QueryRows(db *sql.DB, tableName string, columns string, where string, where
|
||||
columns = "*"
|
||||
}
|
||||
|
||||
queryTableName := dbPrefix + tableName
|
||||
var query string
|
||||
if where != "" {
|
||||
query = fmt.Sprintf("SELECT %s FROM `%s` WHERE %s", columns, tableName, where)
|
||||
query = fmt.Sprintf("SELECT %s FROM `%s` WHERE %s", columns, queryTableName, where)
|
||||
} else {
|
||||
query = fmt.Sprintf("SELECT %s FROM `%s`", columns, tableName)
|
||||
query = fmt.Sprintf("SELECT %s FROM `%s`", columns, queryTableName)
|
||||
}
|
||||
|
||||
rows, err := db.Query(query, whereArgs...)
|
||||
|
||||
@@ -17,6 +17,13 @@ type Config struct {
|
||||
Maxheaderbytes int `json:"max_header_bytes"`
|
||||
} `json:"advanced"`
|
||||
} `json:"server"`
|
||||
Password struct {
|
||||
Memory uint32 `json:"memory"`
|
||||
Iterations uint32 `json:"iterations"`
|
||||
Parallelism uint8 `json:"parallelism"`
|
||||
SaltLength uint32 `json:"salt_length"`
|
||||
KeyLength uint32 `json:"key_length"`
|
||||
} `json:"password"`
|
||||
Database struct {
|
||||
Driver string `json:"driver"`
|
||||
Host string `json:"host"`
|
||||
@@ -28,16 +35,20 @@ type Config struct {
|
||||
} `json:"database"`
|
||||
}
|
||||
|
||||
// ErrorPageData Data model for error page template
|
||||
type ErrorPageData struct {
|
||||
StatusCode int
|
||||
Title string
|
||||
Message string
|
||||
// Argon2Params defines the parameters for Argon2id (adjustable based on server performance)
|
||||
type Argon2Params struct {
|
||||
Memory uint32 // Memory usage (KiB), recommended 64*1024 = 64MB
|
||||
Iterations uint32 // Time cost, recommended 1-3
|
||||
Parallelism uint8 // Number of parallel threads, recommended 2-4
|
||||
SaltLength uint32 // Salt length, recommended 16 bytes
|
||||
KeyLength uint32 // Output hash length, recommended 32 bytes
|
||||
}
|
||||
|
||||
// IndexPageData Data model for index page template
|
||||
type IndexPageData struct {
|
||||
StatusCode int
|
||||
Title string
|
||||
I18n string
|
||||
// hashParts Used to parse stored hash strings
|
||||
type HashParts struct {
|
||||
Memory uint32
|
||||
Iterations uint32
|
||||
Parallelism uint8
|
||||
SaltBase64 string
|
||||
HashBase64 string
|
||||
}
|
||||
|
||||
@@ -12,12 +12,8 @@ func getServerAddress(host string, port int) string {
|
||||
return host + ":" + strconv.Itoa(port)
|
||||
}
|
||||
|
||||
func CreateWebService() *http.Server {
|
||||
func CreateWebService(config *Config) *http.Server {
|
||||
log.Printf("[Info] Create web service")
|
||||
config, err := GetConfig()
|
||||
if err != nil {
|
||||
log.Fatalf("[Error] Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
addr := getServerAddress(config.Server.Host, config.Server.Port)
|
||||
|
||||
@@ -38,11 +34,7 @@ func CreateWebService() *http.Server {
|
||||
return server
|
||||
}
|
||||
|
||||
func ListenWebService(server *http.Server) {
|
||||
config, err := GetConfig()
|
||||
if err != nil {
|
||||
log.Fatalf("[Error] Failed to load config: %v", err)
|
||||
}
|
||||
func ListenWebService(config *Config, server *http.Server) {
|
||||
addr := getServerAddress(config.Server.Host, config.Server.Port)
|
||||
if config.Server.Tls.Enabled {
|
||||
log.Printf("[Info] Starting HTTPS server on %s", addr)
|
||||
@@ -60,7 +52,4 @@ func ListenWebService(server *http.Server) {
|
||||
log.Fatalf("[Error] HTTP server failed: %v", err)
|
||||
}
|
||||
}
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("[Error] Web server terminated unexpectedly: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user