Skip to content

Commit f5b6567

Browse files
authored
Merge pull request #2 from git-pkgs/user-agent
Add User-Agent header to OSV API requests
2 parents 8a0c3a3 + 294fc97 commit f5b6567

2 files changed

Lines changed: 64 additions & 0 deletions

File tree

osv/osv.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ const (
2323
type Source struct {
2424
httpClient *http.Client
2525
baseURL string
26+
userAgent string
2627
}
2728

2829
// Option configures a Source.
@@ -42,11 +43,19 @@ func WithBaseURL(url string) Option {
4243
}
4344
}
4445

46+
// WithUserAgent sets the User-Agent header for API requests.
47+
func WithUserAgent(ua string) Option {
48+
return func(s *Source) {
49+
s.userAgent = ua
50+
}
51+
}
52+
4553
// New creates a new OSV source.
4654
func New(opts ...Option) *Source {
4755
s := &Source{
4856
httpClient: &http.Client{Timeout: DefaultTimeout},
4957
baseURL: DefaultAPIURL,
58+
userAgent: "vulns",
5059
}
5160
for _, opt := range opts {
5261
opt(s)
@@ -82,6 +91,7 @@ func (s *Source) Query(ctx context.Context, p *purl.PURL) ([]vulns.Vulnerability
8291
return nil, fmt.Errorf("creating request: %w", err)
8392
}
8493
httpReq.Header.Set("Content-Type", "application/json")
94+
httpReq.Header.Set("User-Agent", s.userAgent)
8595

8696
resp, err := s.httpClient.Do(httpReq)
8797
if err != nil {
@@ -141,6 +151,7 @@ func (s *Source) QueryBatch(ctx context.Context, purls []*purl.PURL) ([][]vulns.
141151
return nil, fmt.Errorf("creating request: %w", err)
142152
}
143153
httpReq.Header.Set("Content-Type", "application/json")
154+
httpReq.Header.Set("User-Agent", s.userAgent)
144155

145156
resp, err := s.httpClient.Do(httpReq)
146157
if err != nil {
@@ -174,6 +185,7 @@ func (s *Source) Get(ctx context.Context, id string) (*vulns.Vulnerability, erro
174185
if err != nil {
175186
return nil, fmt.Errorf("creating request: %w", err)
176187
}
188+
httpReq.Header.Set("User-Agent", s.userAgent)
177189

178190
resp, err := s.httpClient.Do(httpReq)
179191
if err != nil {

osv/osv_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,58 @@ func TestName(t *testing.T) {
168168
}
169169
}
170170

171+
func TestDefaultUserAgent(t *testing.T) {
172+
var gotUA string
173+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
174+
gotUA = r.Header.Get("User-Agent")
175+
w.Header().Set("Content-Type", "application/json")
176+
_, _ = w.Write([]byte(`{"vulns": []}`))
177+
}))
178+
defer server.Close()
179+
180+
source := New(WithBaseURL(server.URL))
181+
p := purl.MakePURL("npm", "test", "1.0.0")
182+
_, _ = source.Query(context.Background(), p)
183+
184+
if gotUA != "vulns" {
185+
t.Errorf("default User-Agent = %q, want %q", gotUA, "vulns")
186+
}
187+
}
188+
189+
func TestCustomUserAgent(t *testing.T) {
190+
var gotUA string
191+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
192+
gotUA = r.Header.Get("User-Agent")
193+
w.Header().Set("Content-Type", "application/json")
194+
_, _ = w.Write([]byte(`{"vulns": []}`))
195+
}))
196+
defer server.Close()
197+
198+
source := New(WithBaseURL(server.URL), WithUserAgent("git-pkgs/1.0"))
199+
p := purl.MakePURL("npm", "test", "1.0.0")
200+
_, _ = source.Query(context.Background(), p)
201+
202+
if gotUA != "git-pkgs/1.0" {
203+
t.Errorf("User-Agent = %q, want %q", gotUA, "git-pkgs/1.0")
204+
}
205+
}
206+
207+
func TestUserAgentOnGet(t *testing.T) {
208+
var gotUA string
209+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
210+
gotUA = r.Header.Get("User-Agent")
211+
w.WriteHeader(http.StatusNotFound)
212+
}))
213+
defer server.Close()
214+
215+
source := New(WithBaseURL(server.URL), WithUserAgent("test-agent"))
216+
_, _ = source.Get(context.Background(), "GHSA-test")
217+
218+
if gotUA != "test-agent" {
219+
t.Errorf("Get User-Agent = %q, want %q", gotUA, "test-agent")
220+
}
221+
}
222+
171223
func loadFixture(t *testing.T, name string) []byte {
172224
t.Helper()
173225
path := filepath.Join("testdata", name)

0 commit comments

Comments
 (0)