From 88930c5558d9ee3f31dc2a9723dca3b4a6432099 Mon Sep 17 00:00:00 2001 From: Oleg Broslavsky Date: Mon, 25 Mar 2019 22:23:48 +0700 Subject: [PATCH] Fix multiple NotificationQuery bugs - Switch to public `ConnectSWbemServices` - Set `NotificationQuery.Decoder.Dereferencer` to allow work with ref fields - Fix query timeout calculation - Fix timeout error check --- notification_query.go | 42 ++++++++++++--------------------- notification_query_test.go | 48 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 27 deletions(-) diff --git a/notification_query.go b/notification_query.go index 5236a41..e93f5c4 100644 --- a/notification_query.go +++ b/notification_query.go @@ -75,7 +75,7 @@ func (q *NotificationQuery) SetNotificationTimeout(t time.Duration) { q.queryTimeoutMs = -1 return } - q.queryTimeoutMs = int64(t / time.Microsecond) + q.queryTimeoutMs = int64(t / time.Millisecond) } // SetConnectServerArgs sets `SWbemLocator.ConnectServer` args. Args are @@ -119,16 +119,21 @@ func (q *NotificationQuery) StartNotifications() (err error) { defer comshim.Done() // Connect to WMI service. - service, err := createWMIConnection(q.connectServerArgs...) + service, err := ConnectSWbemServices(q.connectServerArgs...) if err != nil { return fmt.Errorf("failed to connect WMI service; %s", err) } - defer service.Release() + defer func() { + if clErr := service.Close(); clErr != nil { + err = multierror.Append(err, clErr) + } + }() + q.Dereferencer = service // Subscribe to the events. ExecNotificationQuery call must have that flags // and no other. sWbemEventSource, err := oleutil.CallMethod( - service, + service.sWbemServices, "ExecNotificationQuery", q.query, "WQL", @@ -166,6 +171,7 @@ func (q *NotificationQuery) StartNotifications() (err error) { if err := q.Unmarshal(event, e.Interface()); err != nil { return fmt.Errorf("failed to unmarshal event; %s", err) } + _ = eventIUnknown.Clear() // Nah. We can't handle it anyway. // Send to the user. sent := trySend(reflectedResChan, reflectedDoneChan, e.Elem()) @@ -187,28 +193,6 @@ func (q *NotificationQuery) Stop() { q.state = stateStopped } -func createWMIConnection(connectServerArgs ...interface{}) (wmi *ole.IDispatch, err error) { - sWbemLocatorIUnknown, err := oleutil.CreateObject("WbemScripting.SWbemLocator") - if err != nil { - return nil, fmt.Errorf("failed to create SWbemLocator; %s", err) - } else if sWbemLocatorIUnknown == nil { - return nil, ErrNilCreateObject - } - defer sWbemLocatorIUnknown.Release() - - sWbemLocatorIDispatch, err := sWbemLocatorIUnknown.QueryInterface(ole.IID_IDispatch) - if err != nil { - return nil, fmt.Errorf("SWbemLocator.QueryInterface failed ; %s", err) - } - defer sWbemLocatorIDispatch.Release() - - serviceRaw, err := oleutil.CallMethod(sWbemLocatorIDispatch, "ConnectServer", connectServerArgs...) - if err != nil { - return nil, fmt.Errorf("SWbemLocator.ConnectServer failed; %s", err) - } - return serviceRaw.ToIDispatch(), nil -} - type state int const ( @@ -219,7 +203,11 @@ const ( func isTimeoutError(err error) bool { oleErr, ok := err.(*ole.OleError) - return ok && oleErr.Code() == wbemErrTimedOut + if !ok { + return false + } + exception, ok := oleErr.SubError().(ole.EXCEPINFO) + return ok && exception.SCODE() == wbemErrTimedOut } func isChannelTypeOK(eventCh interface{}) bool { diff --git a/notification_query_test.go b/notification_query_test.go index 301497c..25428a1 100644 --- a/notification_query_test.go +++ b/notification_query_test.go @@ -131,3 +131,51 @@ func TestNotificationQuery_StartStop(t *testing.T) { t.Errorf("Failed to stop query in 5x NotificationTimeout's") } } + +func TestNotificationQuery_StopWithNoEvents(t *testing.T) { + type event struct { + Created uint64 `wmi:"TIME_CREATED"` + } + + // Create a query that will never receive an event. + resultCh := make(chan event) + queryString := ` +SELECT * FROM __InstanceModificationEvent +WHERE TargetInstance ISA 'Win32_LocalTime' AND TargetInstance.Hour = 25` // Should never happen. + + query, err := NewNotificationQuery(resultCh, queryString) + if err != nil { + t.Fatalf("Failed to create NotificationQuery; %s", err) + } + query.SetNotificationTimeout(100 * time.Millisecond) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + if err := query.StartNotifications(); err != nil { + t.Errorf("Notification query error; %s", err) + } + wg.Done() + }() + + // We can't get an event, but emulate some tries. + select { + case e := <-resultCh: + t.Errorf("OMFG! Got timer event with Hour == 25; %+v", e) + case <-time.After(500 * time.Millisecond): + // Ok. As intended. + } + + // Stop the query and confirm routine is dead. + query.Stop() + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + select { + case <-done: + case <-time.After(500 * time.Millisecond): + t.Errorf("Failed to stop query in 5x NotificationTimeout's") + } +}