Skip to content

Commit be7c549

Browse files
committed
drivers implement context
1 parent 478acc7 commit be7c549

8 files changed

Lines changed: 44 additions & 29 deletions

File tree

certgraph.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
package main
1212

1313
import (
14+
"context"
1415
"embed"
1516
"encoding/json"
1617
"flag"
@@ -389,7 +390,9 @@ func visit(domainNode *graph.DomainNode) {
389390

390391
// perform cert search
391392
// TODO do pagination in multiple threads to not block on long searches
392-
results, err := certDriver.QueryDomain(domainNode.Domain)
393+
ctx, cancel := context.WithTimeout(context.Background(), config.timeout)
394+
defer cancel()
395+
results, err := certDriver.QueryDomain(ctx, domainNode.Domain)
393396
if err != nil {
394397
// this is VERY common to error, usually this is a DNS or tcp connection related issue
395398
// we will skip the domain if we can't query it
@@ -430,7 +433,7 @@ func visit(domainNode *graph.DomainNode) {
430433
processedCerts[fp] = true
431434

432435
// get cert details
433-
certResult, err := results.QueryCert(fp)
436+
certResult, err := results.QueryCert(ctx, fp)
434437
if err != nil {
435438
v("QueryCert", err)
436439
continue

driver/censys/censys.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package censys
77

88
import (
99
"bytes"
10+
"context"
1011
"encoding/base64"
1112
"encoding/json"
1213
"flag"
@@ -66,8 +67,8 @@ func (c *censysCertDriver) GetRelated() ([]string, error) {
6667
return nil, nil // Return nil instead of empty slice for better memory efficiency
6768
}
6869

69-
func (c *censysCertDriver) QueryCert(fp fingerprint.Fingerprint) (*driver.CertResult, error) {
70-
return c.driver.QueryCert(fp)
70+
func (c *censysCertDriver) QueryCert(ctx context.Context, fp fingerprint.Fingerprint) (*driver.CertResult, error) {
71+
return c.driver.QueryCert(ctx, fp)
7172
}
7273

7374
// TODO support pagination
@@ -188,7 +189,7 @@ func (d *censys) jsonRequest(method, url string, request, response interface{})
188189
return nil
189190
}
190191

191-
func (d *censys) QueryDomain(domain string) (driver.Result, error) {
192+
func (d *censys) QueryDomain(ctx context.Context, domain string) (driver.Result, error) {
192193
results := &censysCertDriver{
193194
host: domain,
194195
fingerprints: make(driver.FingerprintMap),
@@ -218,7 +219,7 @@ func (d *censys) QueryDomain(domain string) (driver.Result, error) {
218219
return results, nil
219220
}
220221

221-
func (d *censys) QueryCert(fp fingerprint.Fingerprint) (*driver.CertResult, error) {
222+
func (d *censys) QueryCert(ctx context.Context, fp fingerprint.Fingerprint) (*driver.CertResult, error) {
222223
certNode := new(driver.CertResult)
223224
certNode.Fingerprint = fp
224225
certNode.Domains = make([]string, 0, 5)

driver/crtsh/crtsh.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package crtsh
99

1010
import (
11+
"context"
1112
"database/sql"
1213
"fmt"
1314
"log"
@@ -57,8 +58,8 @@ func (c *crtshCertDriver) GetRelated() ([]string, error) {
5758
return nil, nil // Return nil instead of empty slice for better memory efficiency
5859
}
5960

60-
func (c *crtshCertDriver) QueryCert(fp fingerprint.Fingerprint) (*driver.CertResult, error) {
61-
return c.driver.QueryCert(fp)
61+
func (c *crtshCertDriver) QueryCert(ctx context.Context, fp fingerprint.Fingerprint) (*driver.CertResult, error) {
62+
return c.driver.QueryCert(ctx, fp)
6263
}
6364

6465
// Driver creates a new CT driver for crt.sh
@@ -108,7 +109,7 @@ func (d *crtsh) setSQLTimeout(sec float64) error {
108109
return err
109110
}
110111

111-
func (d *crtsh) QueryDomain(domain string) (driver.Result, error) {
112+
func (d *crtsh) QueryDomain(ctx context.Context, domain string) (driver.Result, error) {
112113
results := &crtshCertDriver{
113114
host: domain,
114115
fingerprints: make(driver.FingerprintMap),
@@ -164,7 +165,7 @@ func (d *crtsh) QueryDomain(domain string) (driver.Result, error) {
164165
if debug {
165166
log.Printf("QueryDomain try %d: %s", try, queryStr)
166167
}
167-
rows, err = d.db.Query(queryStr, d.includeExpired, d.includeSubdomains, d.queryLimit, domain)
168+
rows, err = d.db.QueryContext(ctx, queryStr, d.includeExpired, d.includeSubdomains, d.queryLimit, domain)
168169
if err == nil {
169170
break
170171
}
@@ -202,7 +203,7 @@ func (d *crtsh) QueryDomain(domain string) (driver.Result, error) {
202203
return results, nil
203204
}
204205

205-
func (d *crtsh) QueryCert(fp fingerprint.Fingerprint) (*driver.CertResult, error) {
206+
func (d *crtsh) QueryCert(ctx context.Context, fp fingerprint.Fingerprint) (*driver.CertResult, error) {
206207
certNode := new(driver.CertResult)
207208
certNode.Fingerprint = fp
208209
certNode.Domains = make([]string, 0, 5)
@@ -217,7 +218,7 @@ func (d *crtsh) QueryCert(fp fingerprint.Fingerprint) (*driver.CertResult, error
217218
for try < 5 {
218219
// this is a hack while crt.sh gets there stuff together
219220
try++
220-
rows, err = d.db.Query(queryStr, fp[:])
221+
rows, err = d.db.QueryContext(ctx, queryStr, fp[:])
221222
if err == nil {
222223
break
223224
}

driver/driver.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
package driver
33

44
import (
5+
"context"
56
"crypto/x509"
67
"sort"
78
"strings"
@@ -10,8 +11,6 @@ import (
1011
"github.com/lanrat/certgraph/status"
1112
)
1213

13-
// TODO add context instead of timeout on all requests
14-
1514
// Drivers contains all the drivers that have been registered
1615
var Drivers []string
1716

@@ -27,7 +26,7 @@ type Driver interface {
2726
// QueryDomain is the main entrypoint for Driver Searching
2827
// The domain provided will return a CertDriver instance which can be used to query the
2928
// certificates for the provided domain using the driver
30-
QueryDomain(domain string) (Result, error)
29+
QueryDomain(ctx context.Context, domain string) (Result, error)
3130

3231
// GetName returns the name of the driver
3332
GetName() string
@@ -46,7 +45,7 @@ type Result interface {
4645
GetFingerprints() (FingerprintMap, error)
4746

4847
// QueryCert returns the details of the provided certificate or an error if not found
49-
QueryCert(fp fingerprint.Fingerprint) (*CertResult, error)
48+
QueryCert(ctx context.Context, fp fingerprint.Fingerprint) (*CertResult, error)
5049
}
5150

5251
// FingerprintMap stores a mapping of domains to Fingerprints returned from the driver

driver/example.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
package driver
22

3-
import "fmt"
3+
import (
4+
"context"
5+
"fmt"
6+
)
47

58
// Example provides a simple entrypoint to test a driver on an individual domain
69
func Example(domain string, driver Driver) error {
7-
certDriver, err := driver.QueryDomain(domain)
10+
ctx := context.Background()
11+
certDriver, err := driver.QueryDomain(ctx, domain)
812
if err != nil {
913
return err
1014
}
@@ -27,7 +31,7 @@ func Example(domain string, driver Driver) error {
2731
for domain, fingerprints := range fingerprintMap {
2832
for i := range fingerprints {
2933
fmt.Printf("%s: %s\n", domain, fingerprints[i].HexString())
30-
cert, err := certDriver.QueryCert(fingerprints[i])
34+
cert, err := certDriver.QueryCert(ctx, fingerprints[i])
3135
if err != nil {
3236
return err
3337
}

driver/http/http.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
package http
33

44
import (
5+
"context"
56
"crypto/tls"
67
"fmt"
78
"net"
@@ -58,7 +59,7 @@ func (c *httpCertDriver) GetRelated() ([]string, error) {
5859

5960
// QueryCert retrieves certificate details for a specific fingerprint.
6061
// Returns an error if the certificate was not found in this HTTP query.
61-
func (c *httpCertDriver) QueryCert(fp fingerprint.Fingerprint) (*driver.CertResult, error) {
62+
func (c *httpCertDriver) QueryCert(ctx context.Context, fp fingerprint.Fingerprint) (*driver.CertResult, error) {
6263
cert, found := c.certs[fp]
6364
if found {
6465
return cert, nil
@@ -121,10 +122,15 @@ func (d *httpDriver) newHTTPCertDriver() *httpCertDriver {
121122

122123
// QueryDomain discovers certificates for a domain through HTTPS connections.
123124
// Follows redirects and collects certificates from all encountered servers.
124-
func (d *httpDriver) QueryDomain(host string) (driver.Result, error) {
125+
func (d *httpDriver) QueryDomain(ctx context.Context, host string) (driver.Result, error) {
125126
results := d.newHTTPCertDriver()
126127

127-
resp, err := results.client.Get(fmt.Sprintf("https://%s", host))
128+
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("https://%s", host), nil)
129+
if err != nil {
130+
return nil, fmt.Errorf("failed to create request: %w", err)
131+
}
132+
133+
resp, err := results.client.Do(req)
128134
fullStatus := status.CheckNetErr(err)
129135
if fullStatus != status.GOOD {
130136
return results, err // in some rare cases this error can be ignored

driver/multi/multi_driver.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
package multi
33

44
import (
5+
"context"
56
"errors"
67
"fmt"
78
"strings"
@@ -39,14 +40,14 @@ func (d *multiDriver) GetName() string {
3940

4041
// QueryDomain executes domain queries against all drivers concurrently.
4142
// Returns a merged result containing certificates and status information from all drivers.
42-
func (d *multiDriver) QueryDomain(domain string) (driver.Result, error) {
43+
func (d *multiDriver) QueryDomain(ctx context.Context, domain string) (driver.Result, error) {
4344
r := newResult(domain)
44-
var group errgroup.Group
45+
group, ctx := errgroup.WithContext(ctx)
4546
for _, d := range d.drivers {
4647
goFunc := func(localDriver driver.Driver) func() error {
4748
return func() error {
4849
return func(localDriver driver.Driver) error {
49-
result, err := localDriver.QueryDomain(domain)
50+
result, err := localDriver.QueryDomain(ctx, domain)
5051
if err != nil {
5152
return err
5253
}
@@ -104,9 +105,9 @@ func (c *multiResult) add(r driver.Result) error {
104105

105106
// QueryCert attempts to retrieve certificate details from any of the drivers.
106107
// Returns the first successful result found among the combined drivers.
107-
func (c *multiResult) QueryCert(fp fingerprint.Fingerprint) (*driver.CertResult, error) {
108+
func (c *multiResult) QueryCert(ctx context.Context, fp fingerprint.Fingerprint) (*driver.CertResult, error) {
108109
for _, result := range c.results {
109-
cr, err := result.QueryCert(fp)
110+
cr, err := result.QueryCert(ctx, fp)
110111
if err != nil {
111112
return nil, err
112113
}

driver/smtp/smtp.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ func (c *smtpCertDriver) GetRelated() ([]string, error) {
6060

6161
// QueryCert retrieves certificate details for a specific fingerprint.
6262
// Returns an error if the certificate was not found in this SMTP query.
63-
func (c *smtpCertDriver) QueryCert(fp fingerprint.Fingerprint) (*driver.CertResult, error) {
63+
func (c *smtpCertDriver) QueryCert(ctx context.Context, fp fingerprint.Fingerprint) (*driver.CertResult, error) {
6464
cert, found := c.certs[fp]
6565
if found {
6666
return cert, nil
@@ -119,7 +119,7 @@ func (d *smtpDriver) smtpGetCerts(host string) ([]*x509.Certificate, error) {
119119

120120
// QueryDomain discovers certificates for a domain through SMTP STARTTLS.
121121
// Also performs MX record lookups to find related mail server domains.
122-
func (d *smtpDriver) QueryDomain(host string) (driver.Result, error) {
122+
func (d *smtpDriver) QueryDomain(ctx context.Context, host string) (driver.Result, error) {
123123
results := &smtpCertDriver{
124124
host: host,
125125
status: make(status.Map),

0 commit comments

Comments
 (0)