refactor: update driver caching mechanism to use generic CacheManager and improve metadata handling

This commit is contained in:
lilong.129
2025-06-22 21:42:50 +08:00
parent e48bbb2271
commit 6cc3c3acb5
5 changed files with 251 additions and 155 deletions

View File

@@ -1 +1 @@
v5.0.0-beta-2506221325
v5.0.0-beta-2506222142

View File

@@ -681,10 +681,10 @@ func (r *SessionRunner) Start(givenVars map[string]interface{}) (summary *TestCa
for _, cached := range uixt.ListCachedDrivers() {
// add WDA/UIA logs to summary
logs := map[string]interface{}{
"uuid": cached.Serial,
"uuid": cached.Key,
}
client := cached.Driver
client := cached.Item
if client.GetDevice().LogEnabled() {
log, err1 := client.StopCaptureLog()
if err1 != nil {

View File

@@ -10,16 +10,184 @@ import (
"github.com/rs/zerolog/log"
)
var driverCache sync.Map // key is serial, value is *CachedXTDriver
// CachedXTDriver wraps XTDriver with additional cache metadata
type CachedXTDriver struct {
Platform string
Serial string
Driver *XTDriver
RefCount int32 // reference count for resource management
// CacheManager provides a generic cache management interface
type CacheManager[T any] struct {
cache sync.Map
name string // cache name for logging
cleanup func(T) error // cleanup function for cached items
}
// NewCacheManager creates a new cache manager
func NewCacheManager[T any](name string, cleanup func(T) error) *CacheManager[T] {
return &CacheManager[T]{
cache: sync.Map{},
name: name,
cleanup: cleanup,
}
}
// CachedItem wraps an item with cache metadata
type CachedItem[T any] struct {
Key string
Item T
RefCount int32
Metadata map[string]interface{} // additional metadata
}
// Get retrieves an item from cache
func (cm *CacheManager[T]) Get(key string) (*CachedItem[T], bool) {
if item, ok := cm.cache.Load(key); ok {
if cached, ok := item.(*CachedItem[T]); ok {
cached.RefCount++
log.Debug().
Str("cache", cm.name).
Str("key", key).
Int32("refCount", cached.RefCount).
Msg("Retrieved item from cache")
return cached, true
}
}
return nil, false
}
// Set stores an item in cache
func (cm *CacheManager[T]) Set(key string, item T, metadata map[string]interface{}) *CachedItem[T] {
cached := &CachedItem[T]{
Key: key,
Item: item,
RefCount: 1,
Metadata: metadata,
}
cm.cache.Store(key, cached)
log.Debug().
Str("cache", cm.name).
Str("key", key).
Msg("Stored item in cache")
return cached
}
// Release decrements reference count and removes item if count reaches zero
func (cm *CacheManager[T]) Release(key string) error {
if item, ok := cm.cache.Load(key); ok {
if cached, ok := item.(*CachedItem[T]); ok {
cached.RefCount--
log.Debug().
Str("cache", cm.name).
Str("key", key).
Int32("refCount", cached.RefCount).
Msg("Released item reference")
// If no more references, clean up and remove from cache
if cached.RefCount <= 0 {
cm.cache.Delete(key)
// Clean up item if cleanup function is provided
if cm.cleanup != nil {
if err := cm.cleanup(cached.Item); err != nil {
log.Warn().Err(err).
Str("cache", cm.name).
Str("key", key).
Msg("Failed to cleanup cached item")
return err
}
}
log.Info().
Str("cache", cm.name).
Str("key", key).
Msg("Cleaned up item from cache")
}
}
}
return nil
}
// Clear removes all items from cache
func (cm *CacheManager[T]) Clear() {
cm.cache.Range(func(key, value interface{}) bool {
if keyStr, ok := key.(string); ok {
if cached, ok := value.(*CachedItem[T]); ok {
// Clean up item if cleanup function is provided
if cm.cleanup != nil {
if err := cm.cleanup(cached.Item); err != nil {
log.Warn().Err(err).
Str("cache", cm.name).
Str("key", keyStr).
Msg("Failed to cleanup cached item")
}
}
log.Debug().
Str("cache", cm.name).
Str("key", keyStr).
Msg("Cleaned up item from cache")
}
cm.cache.Delete(key)
}
return true
})
log.Info().Str("cache", cm.name).Msg("Cleared all cached items")
}
// List returns all cached items
func (cm *CacheManager[T]) List() []CachedItem[T] {
var items []CachedItem[T]
cm.cache.Range(func(key, value interface{}) bool {
if cached, ok := value.(*CachedItem[T]); ok {
items = append(items, *cached)
}
return true
})
return items
}
// GetOrCreate gets an existing item or creates a new one using the provided factory function
func (cm *CacheManager[T]) GetOrCreate(key string, factory func() (T, map[string]interface{}, error)) (T, error) {
// Check cache first
if cached, ok := cm.Get(key); ok {
return cached.Item, nil
}
// Create new item
item, metadata, err := factory()
if err != nil {
var zero T
return zero, fmt.Errorf("failed to create item: %w", err)
}
// Store in cache
cached := cm.Set(key, item, metadata)
return cached.Item, nil
}
// Size returns the number of items in cache
func (cm *CacheManager[T]) Size() int {
count := 0
cm.cache.Range(func(key, value interface{}) bool {
count++
return true
})
return count
}
// Use cache manager for XTDriver caching
var driverCacheManager = NewCacheManager("xt-driver", cleanupXTDriver)
// cleanupXTDriver cleans up XTDriver resources
func cleanupXTDriver(driver *XTDriver) error {
if driver != nil && driver.IDriver != nil {
if err := driver.DeleteSession(); err != nil {
log.Warn().Err(err).Msg("Failed to delete driver session during cleanup")
return err
}
}
return nil
}
// CachedXTDriver is an alias for CachedItem[*XTDriver] for backward compatibility
type CachedXTDriver = CachedItem[*XTDriver]
// DriverCacheConfig holds configuration for driver creation
type DriverCacheConfig struct {
Platform string
@@ -30,67 +198,48 @@ type DriverCacheConfig struct {
// GetOrCreateXTDriver gets an existing driver from cache or creates a new one
func GetOrCreateXTDriver(config DriverCacheConfig) (*XTDriver, error) {
// If serial is specified, check cache first
if config.Serial != "" {
cacheKey := config.Serial
if cachedItem, ok := driverCache.Load(cacheKey); ok {
if cached, ok := cachedItem.(*CachedXTDriver); ok {
log.Info().Str("serial", cached.Serial).Msg("Using cached XTDriver")
// Increment reference count
cached.RefCount++
return cached.Driver, nil
}
}
}
// If no serial specified, try to find existing driver
// Handle empty serial case - try to find existing driver first
if config.Serial == "" {
if driver := findCachedDriver(config.Platform); driver != nil {
return driver, nil
}
}
// Create new driver (will auto-detect serial if empty)
driverExt, err := createXTDriverWithConfig(config)
if err != nil {
return nil, fmt.Errorf("failed to create XTDriver: %w", err)
}
// Get actual serial from the created driver
actualSerial := driverExt.GetDevice().UUID()
// Check if a driver with this actual serial already exists in cache
if cachedItem, ok := driverCache.Load(actualSerial); ok {
if cached, ok := cachedItem.(*CachedXTDriver); ok {
log.Info().Str("serial", actualSerial).Msg("Found existing cached XTDriver with detected serial")
// Clean up the newly created driver since we have a cached one
if err := driverExt.DeleteSession(); err != nil {
log.Warn().Err(err).Str("serial", actualSerial).Msg("Failed to delete newly created driver session")
}
// Increment reference count and return cached driver
cached.RefCount++
return cached.Driver, nil
// Use shared cache manager's GetOrCreate functionality
return driverCacheManager.GetOrCreate(config.Serial, func() (*XTDriver, map[string]interface{}, error) {
// Create new driver
driverExt, err := createXTDriverWithConfig(config)
if err != nil {
return nil, nil, fmt.Errorf("failed to create XTDriver: %w", err)
}
}
// Cache the new driver with actual serial
cached := &CachedXTDriver{
Platform: config.Platform,
Driver: driverExt,
Serial: actualSerial,
RefCount: 1,
}
driverCache.Store(actualSerial, cached)
// Get actual serial from the created driver
actualSerial := driverExt.GetDevice().UUID()
log.Info().
Str("platform", config.Platform).
Str("serial", actualSerial).
Msg("Created and cached new XTDriver")
// Check if a driver with actual serial already exists (for empty serial case)
if config.Serial == "" && actualSerial != "" {
if existingCached, ok := driverCacheManager.Get(actualSerial); ok {
// Clean up the newly created driver since we have a cached one
if err := driverExt.DeleteSession(); err != nil {
log.Warn().Err(err).Str("serial", actualSerial).Msg("Failed to delete newly created driver session")
}
return existingCached.Item, existingCached.Metadata, nil
}
}
return driverExt, nil
// Create metadata
metadata := map[string]interface{}{
"platform": config.Platform,
"serial": actualSerial,
}
log.Info().
Str("platform", config.Platform).
Str("serial", actualSerial).
Msg("Created and cached new XTDriver")
return driverExt, metadata, nil
})
}
// createXTDriverWithConfig creates a new XTDriver based on configuration
@@ -184,94 +333,41 @@ func createXTDriverWithConfig(config DriverCacheConfig) (*XTDriver, error) {
// ReleaseXTDriver decrements reference count and removes from cache when count reaches zero
func ReleaseXTDriver(serial string) error {
if cachedItem, ok := driverCache.Load(serial); ok {
if cached, ok := cachedItem.(*CachedXTDriver); ok {
cached.RefCount--
log.Debug().
Str("serial", serial).
Int32("refCount", cached.RefCount).
Msg("Released XTDriver reference")
// If no more references, clean up and remove from cache
if cached.RefCount <= 0 {
driverCache.Delete(serial)
// Clean up driver resources if driver has underlying IDriver
if cached.Driver != nil && cached.Driver.IDriver != nil {
if err := cached.Driver.DeleteSession(); err != nil {
log.Warn().Err(err).Str("serial", serial).Msg("Failed to delete driver session")
}
}
log.Info().Str("serial", serial).Msg("Cleaned up XTDriver from cache")
}
}
}
return nil
return driverCacheManager.Release(serial)
}
// CleanupAllDrivers cleans up all cached drivers
func CleanupAllDrivers() {
driverCache.Range(func(key, value interface{}) bool {
if serial, ok := key.(string); ok {
if cached, ok := value.(*CachedXTDriver); ok {
// Clean up driver resources if driver has underlying IDriver
if cached.Driver != nil && cached.Driver.IDriver != nil {
if err := cached.Driver.DeleteSession(); err != nil {
log.Warn().Err(err).Str("serial", serial).Msg("Failed to delete driver session")
}
}
log.Info().Str("serial", serial).Msg("Cleaned up XTDriver from cache")
}
driverCache.Delete(serial)
}
return true
})
driverCacheManager.Clear()
}
// ListCachedDrivers returns information about all cached drivers
func ListCachedDrivers() []CachedXTDriver {
var drivers []CachedXTDriver
driverCache.Range(func(key, value interface{}) bool {
if cached, ok := value.(*CachedXTDriver); ok {
drivers = append(drivers, *cached)
}
return true
})
return drivers
return driverCacheManager.List()
}
// findCachedDriver searches for a cached driver by platform
// If platform is empty, returns any available driver
func findCachedDriver(platform string) *XTDriver {
var foundDriver *XTDriver
driverCache.Range(func(key, value interface{}) bool {
serial, ok := key.(string)
if !ok {
return true // continue iteration
}
cachedItems := driverCacheManager.List()
cached, ok := value.(*CachedXTDriver)
if !ok {
return true // continue iteration
}
for _, cachedItem := range cachedItems {
cachedPlatform, _ := cachedItem.Metadata["platform"].(string)
// If platform is specified, match platform; otherwise use any available driver
if platform == "" || cached.Platform == platform {
foundDriver = cached.Driver
cached.RefCount++
if platform != "" {
log.Debug().Str("platform", platform).Str("serial", serial).Msg("Using cached XTDriver by platform")
} else {
log.Debug().Str("serial", serial).Msg("Using any available cached XTDriver")
if platform == "" || cachedPlatform == platform {
// Increment reference count by getting from cache
if refreshedItem, ok := driverCacheManager.Get(cachedItem.Key); ok {
if platform != "" {
log.Debug().Str("platform", platform).Str("serial", cachedItem.Key).Msg("Using cached XTDriver by platform")
} else {
log.Debug().Str("serial", cachedItem.Key).Msg("Using any available cached XTDriver")
}
return refreshedItem.Item
}
return false // stop iteration
}
return true // continue iteration
})
return foundDriver
}
return nil
}
// setupXTDriver initializes an XTDriver based on the platform and serial.
@@ -310,12 +406,14 @@ func RegisterXTDriver(serial string, driver *XTDriver) error {
return fmt.Errorf("driver cannot be nil")
}
cached := &CachedXTDriver{
Driver: driver,
Serial: serial,
RefCount: 1,
// Create metadata
metadata := map[string]interface{}{
"platform": "external", // Mark as externally registered
"serial": serial,
}
driverCache.Store(serial, cached)
// Store in cache using shared cache manager
driverCacheManager.Set(serial, driver, metadata)
log.Info().
Str("serial", serial).
@@ -343,8 +441,8 @@ func getXTDriverFromCache(driver IDriver) *XTDriver {
// Get XTDriver from cache using device UUID as serial
cachedDrivers := ListCachedDrivers()
for _, cached := range cachedDrivers {
if cached.Serial == deviceUUID {
return cached.Driver
if serial, _ := cached.Metadata["serial"].(string); serial == deviceUUID {
return cached.Item
}
}

View File

@@ -33,7 +33,7 @@ func TestGetOrCreateXTDriver_EmptySerial_AutoDetect(t *testing.T) {
// Verify that a driver was created and cached with actual serial
drivers := ListCachedDrivers()
assert.Len(t, drivers, 1)
assert.NotEmpty(t, drivers[0].Serial) // Serial should be populated with actual device serial
assert.NotEmpty(t, drivers[0].Key) // Serial should be populated with actual device serial
}
}
@@ -57,7 +57,7 @@ func TestGetOrCreateXTDriver_EmptySerial_DefaultPlatform(t *testing.T) {
// Verify that a driver was created and cached with actual serial
drivers := ListCachedDrivers()
assert.Len(t, drivers, 1)
assert.NotEmpty(t, drivers[0].Serial) // Serial should be populated with actual device serial
assert.NotEmpty(t, drivers[0].Key) // Serial should be populated with actual device serial
}
}
@@ -168,9 +168,9 @@ func TestRegisterXTDriver_Success(t *testing.T) {
// Verify driver is cached
drivers := ListCachedDrivers()
assert.Len(t, drivers, 1)
assert.Equal(t, "external_001", drivers[0].Serial)
assert.Equal(t, "external_001", drivers[0].Key)
assert.Equal(t, int32(1), drivers[0].RefCount)
assert.Equal(t, xtDriver, drivers[0].Driver)
assert.Equal(t, xtDriver, drivers[0].Item)
}
func TestReleaseXTDriver_NonExistentSerial(t *testing.T) {
@@ -255,9 +255,9 @@ func TestListCachedDrivers_Multiple(t *testing.T) {
// Verify driver information
serials := make(map[string]bool)
for _, cached := range drivers {
serials[cached.Serial] = true
serials[cached.Key] = true
assert.Equal(t, int32(1), cached.RefCount)
assert.NotNil(t, cached.Driver)
assert.NotNil(t, cached.Item)
}
assert.True(t, serials["list_test_1"])
assert.True(t, serials["list_test_2"])
@@ -402,7 +402,7 @@ func TestIntegrationExample_TraditionalWay(t *testing.T) {
// Verify registration
drivers := ListCachedDrivers()
assert.Len(t, drivers, 1)
assert.Equal(t, "integration_002", drivers[0].Serial)
assert.Equal(t, "integration_002", drivers[0].Key)
// Clean up
err = ReleaseXTDriver("integration_002")
@@ -555,11 +555,9 @@ func TestCacheReferenceCountManagement(t *testing.T) {
assert.Len(t, drivers, 1)
assert.Equal(t, int32(1), drivers[0].RefCount)
// Simulate multiple references by manually incrementing
if cachedItem, ok := driverCache.Load(serial); ok {
if cached, ok := cachedItem.(*CachedXTDriver); ok {
cached.RefCount++
}
// Simulate multiple references by getting from cache (which increments ref count)
if cachedItem, ok := driverCacheManager.Get(serial); ok {
assert.NotNil(t, cachedItem.Item)
}
// Verify ref count increased

View File

@@ -327,7 +327,7 @@ func NewMCPSuccessResponse(message string, actionTool ActionTool) *mcp.CallToolR
"message": message,
}
// Add all tool-specific fields at the same level
// Add tool-specific fields if provided
toolData := convertToolToData(actionTool)
for key, value := range toolData {
response[key] = value
@@ -336,7 +336,7 @@ func NewMCPSuccessResponse(message string, actionTool ActionTool) *mcp.CallToolR
return marshalToMCPResult(response)
}
// convertToolToData converts tool struct to map[string]any for Data field
// convertToolToData converts tool struct to map for response
func convertToolToData(tool interface{}) map[string]any {
data := make(map[string]any)
@@ -381,7 +381,7 @@ func convertToolToData(tool interface{}) map[string]any {
return data
}
// NewMCPErrorResponse creates an error response
// NewMCPErrorResponse creates an error MCP response
func NewMCPErrorResponse(message string) *mcp.CallToolResult {
response := map[string]any{
"success": false,
@@ -419,7 +419,7 @@ func GenerateReturnSchema(toolStruct interface{}) map[string]string {
for i := 0; i < structType.NumField(); i++ {
field := structType.Field(i)
// Skip embedded MCPResponse fields (though they shouldn't exist now)
// Skip embedded MCPResponse fields
if field.Type.Name() == "MCPResponse" {
continue
}