diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 6ecfaba..3c716e8 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -116,7 +116,9 @@ Fastest paths to understand and navigate the codebase: TypedDicts (`RequestLoadDict`, `VariableDetail`, `LocalOverride`). - **HTTP subsystem:** Read `src/services/http/__init__.py` — re-exports `HttpService`, `GraphQLSchemaService`, `SnippetGenerator`, - `HttpResponseDict`, `parse_header_dict`. + `SnippetOptions`, `HttpResponseDict`, `parse_header_dict`. + Auth header injection lives in `src/services/http/auth_handler.py`. + OAuth 2.0 token exchange lives in `src/services/http/oauth2_service.py`. - **All DB models:** Read `src/database/database.py` — re-exports all four ORM models (`CollectionModel`, `RequestModel`, `SavedResponseModel`, `EnvironmentModel`). @@ -159,7 +161,13 @@ src/ │ ├── http/ # HTTP request/response handling │ │ ├── http_service.py # HttpService (httpx) + response TypedDicts │ │ ├── graphql_schema_service.py # GraphQL introspection + schema parsing -│ │ ├── snippet_generator.py # Code snippet generation (cURL/Python/JS) +│ │ ├── auth_handler.py # Shared auth header injection (all 12 auth types) +│ │ ├── oauth2_service.py # OAuth 2.0 token exchange (4 grant types) +│ │ ├── snippet_generator/ # Code snippet generation sub-package (23 languages) +│ │ │ ├── generator.py # SnippetGenerator, SnippetOptions, LanguageEntry, registry +│ │ │ ├── shell_snippets.py # cURL, HTTP raw, wget, HTTPie, PowerShell +│ │ │ ├── dynamic_snippets.py # Python, JS, Node, Ruby, PHP, Dart +│ │ │ └── compiled_snippets.py # Go, Rust, C, Swift, Java, Kotlin, C# │ │ └── header_utils.py # Shared header parsing utility │ └── import_parser/ # Parser sub-package │ ├── models.py # TypedDict schemas for parsed data @@ -170,9 +178,14 @@ src/ ├── main_window/ # Top-level MainWindow sub-package │ ├── window.py # MainWindow widget + signal wiring │ ├── send_pipeline.py # _SendPipelineMixin — HTTP send/response flow + │ ├── draft_controller.py # _DraftControllerMixin — draft tab open/save │ ├── tab_controller.py # _TabControllerMixin — tab open/close/switch - │ └── variable_controller.py # _VariableControllerMixin — env variable management + │ └── variable_controller.py # _VariableControllerMixin — env variable + sidebar management ├── loading_screen.py # Loading screen overlay widget + ├── sidebar/ # Right sidebar sub-package + │ ├── sidebar_widget.py # RightSidebar (icon rail) + _FlyoutPanel + │ ├── variables_panel.py # VariablesPanel — read-only variable display + │ └── snippet_panel.py # SnippetPanel — inline code snippet generator ├── styling/ # Visual theming and icons │ ├── theme.py # Palettes, colours, badge geometry, method_color() │ ├── theme_manager.py # ThemeManager — QPalette + QSettings @@ -192,6 +205,7 @@ src/ ├── collections/ # Collection sidebar │ ├── collection_header.py │ ├── collection_widget.py + │ ├── new_item_popup.py # NewItemPopup — Postman-style icon grid popup │ └── tree/ # Tree widget sub-package │ ├── constants.py │ ├── draggable_tree_widget.py @@ -199,9 +213,9 @@ src/ │ ├── tree_actions.py # _TreeActionsMixin — context menus, rename, delete │ └── collection_tree_delegate.py # Custom delegate for method badges ├── dialogs/ # Modal dialogs - │ ├── code_snippet_dialog.py │ ├── collection_runner.py │ ├── import_dialog.py + │ ├── save_request_dialog.py # Save draft request to collection │ └── settings_dialog.py # Settings (theme, colour scheme) ├── environments/ # Environment management widgets │ ├── environment_editor.py @@ -212,9 +226,15 @@ src/ └── request/ # Request/response editing ├── folder_editor.py # Folder/collection detail editor ├── http_worker.py # HttpSendWorker + SchemaFetchWorker (QThread) + ├── auth/ # Shared auth sub-package (14 auth types) + │ ├── auth_field_specs.py # Per-type FieldSpec definitions (AUTH_FIELD_SPECS) + │ ├── auth_mixin.py # _AuthMixin — shared by both editors + │ ├── auth_pages.py # FieldSpec dataclass, page builders, auth constants + │ ├── auth_serializer.py # Generic load/save for all auth types + │ └── oauth2_page.py # OAuth 2.0 custom page (grant-type switching) ├── request_editor/ # RequestEditor sub-package │ ├── editor_widget.py # RequestEditor — main request editing widget - │ ├── auth.py # _AuthMixin — authentication UI + │ ├── auth.py # Re-export of _AuthMixin from auth sub-package │ ├── body_search.py # _BodySearchMixin — search/replace in body │ └── graphql.py # _GraphQLMixin — GraphQL mode + schema ├── response_viewer/ # ResponseViewer sub-package @@ -223,7 +243,7 @@ src/ ├── navigation/ # Tab switching and path navigation │ ├── breadcrumb_bar.py │ ├── request_tab_bar.py - │ └── tab_manager.py # TabManager + TabContext (with local_overrides) + │ └── tab_manager.py # TabManager + TabContext (with local_overrides, draft_name) └── popups/ # Response metadata popups ├── status_popup.py # HTTP status code explanation ├── timing_popup.py # Request timing breakdown @@ -243,14 +263,24 @@ tests/ │ └── http/ # HTTP service tests │ ├── test_http_service.py │ ├── test_graphql_schema_service.py -│ └── test_snippet_generator.py +│ ├── test_snippet_generator.py +│ ├── test_snippet_shell.py +│ ├── test_snippet_dynamic.py +│ ├── test_snippet_compiled.py +│ ├── test_auth_handler.py +│ └── test_oauth2_service.py └── ui/ # End-to-end PySide6 widget tests ├── conftest.py # _no_fetch (autouse) + helpers ├── test_main_window.py ├── test_main_window_save.py # SaveButton + RequestSaveEndToEnd tests + ├── test_main_window_draft.py # Draft tab open/save lifecycle tests ├── styling/ # Theme and icon tests │ ├── test_theme_manager.py │ └── test_icons.py + ├── sidebar/ # Sidebar widget tests + │ ├── test_sidebar.py + │ ├── test_variables_panel.py + │ └── test_snippet_panel.py ├── widgets/ # Shared component tests │ ├── test_code_editor.py │ ├── test_code_editor_folding.py @@ -266,9 +296,11 @@ tests/ │ ├── test_collection_tree.py │ ├── test_collection_tree_actions.py │ ├── test_collection_tree_delegate.py - │ └── test_collection_widget.py + │ ├── test_collection_widget.py + │ └── test_new_item_popup.py ├── dialogs/ # Dialog tests │ ├── test_import_dialog.py + │ ├── test_save_request_dialog.py │ └── test_settings_dialog.py ├── environments/ # Environment widget tests │ ├── test_environment_editor.py diff --git a/.github/instructions/architecture.instructions.md b/.github/instructions/architecture.instructions.md index 677b75b..69f7df2 100644 --- a/.github/instructions/architecture.instructions.md +++ b/.github/instructions/architecture.instructions.md @@ -134,6 +134,10 @@ Key signals to know (always-on summary): - `CollectionWidget.item_action_triggered(str, int, str)` → opens requests/folders in MainWindow. +- `CollectionWidget.draft_request_requested()` → opens a new draft + (unsaved) request tab in MainWindow. +- `NewItemPopup.new_request_clicked()` / `new_collection_clicked()` → + emitted by the icon grid popup when tiles are clicked. - `RequestEditorWidget.send_requested()` → triggers HTTP send flow. - `ThemeManager.theme_changed()` → widgets refresh dynamic styles. - `VariablePopup` uses **class-level callbacks**, not signals — wired once @@ -190,6 +194,20 @@ signal instead of relying on `_safe_svc_call`. Children within a folder are **not sorted** — they appear in dict iteration order (insertion order in Python 3.7+). +### 6. Auth inheritance convention + +`auth = None` in the database means "inherit from parent" — the request +or folder walks up its ancestor chain until it finds a folder with an +explicit `auth` dict. `{"type": "noauth"}` means "no authentication" and +**stops** the inheritance chain. The UI maps `None` to +`"Inherit auth from parent"` in the auth type combo. + +- `_get_auth_data()` returns `None` for inherit, `{"type": "noauth"}` for + explicit no-auth. +- `_load_auth(None)` / `_load_auth({})` → selects "Inherit auth from parent". +- `get_request_inherited_auth(request_id)` / `get_collection_inherited_auth(collection_id)` + resolve the effective auth by walking ancestors. + ## Repository and service reference > **Full repository function catalogues, service method tables, TypedDict @@ -220,7 +238,11 @@ order (insertion order in Python 3.7+). combined variable map in `MainWindow._refresh_variable_map()` and tagged with `is_local=True` in `VariableDetail` so the popup can show Update/Reset buttons. -7. **VariablePopup uses class-level callbacks, not Qt signals** — +7. **`TabContext.draft_name` tracks the display name of unsaved tabs** — + Set to `"Untitled Request"` when a draft tab is opened. Updated when + the user renames via the breadcrumb bar. Used as fallback label in the + save-to-collection dialog. `None` for persisted request tabs. +8. **VariablePopup uses class-level callbacks, not Qt signals** — `VariablePopup` is a **singleton** `QFrame`. Its callbacks (`set_save_callback`, `set_local_override_callback`, `set_reset_local_override_callback`, `set_add_variable_callback`, diff --git a/.github/instructions/pyside6.instructions.md b/.github/instructions/pyside6.instructions.md index c54021b..f01c1ad 100644 --- a/.github/instructions/pyside6.instructions.md +++ b/.github/instructions/pyside6.instructions.md @@ -204,6 +204,20 @@ standard object names: | `sidebarSearch` | `QLineEdit` | Collection sidebar search input | | `sidebarSectionLabel` | `QLabel` | Sidebar section heading | | `sidebarToolButton` | `QToolButton` | Sidebar toolbar button | +| `newItemPopup` | `QDialog` | Postman-style "Create New" dialog | +| `newItemTile` | `QPushButton` | Tile button inside the new-item dialog | +| `newItemTileLabel` | `QLabel` | Tile label text inside the dialog | +| `newItemTitle` | `QLabel` | Dialog heading ("What do you want to create?") | +| `newItemDescription` | `QLabel` | Description text below tiles | +| `collectionTree` | `QTreeWidget` | Collection tree in SaveRequestDialog | +| `sidebarRail` | `QWidget` | Always-visible icon rail (RightSidebar widget) | +| `sidebarRailButton` | `QToolButton` | Checkable icon button in the rail | +| `sidebarPanelArea` | `QWidget` | Collapsible flyout panel (separate splitter child) | +| `sidebarTitleLabel` | `QLabel` | Bold panel title in flyout header | +| `variableKeyLabel` | `QLabel` | Variable key in sidebar panel | +| `variableValueLabel` | `QLabel` | Variable value in sidebar panel | +| `sidebarSourceDot` | `QLabel` | Colour-coded variable source dot | +| `sidebarSeparator` | `QFrame` | Separator line in sidebar panels | ### When inline setStyleSheet() is still acceptable diff --git a/.github/instructions/testing.instructions.md b/.github/instructions/testing.instructions.md index e724952..5c9eeb0 100644 --- a/.github/instructions/testing.instructions.md +++ b/.github/instructions/testing.instructions.md @@ -115,11 +115,17 @@ tests/ │ └── http/ # HTTP service tests │ ├── test_http_service.py │ ├── test_graphql_schema_service.py -│ └── test_snippet_generator.py +│ ├── test_snippet_generator.py +│ ├── test_snippet_shell.py +│ ├── test_snippet_dynamic.py +│ ├── test_snippet_compiled.py +│ ├── test_auth_handler.py +│ └── test_oauth2_service.py └── ui/ # PySide6 widget tests (need qapp + qtbot) ├── conftest.py # _no_fetch (autouse) + helper functions ├── test_main_window.py # Top-level MainWindow smoke tests ├── test_main_window_save.py # SaveButton + RequestSaveEndToEnd tests + ├── test_main_window_draft.py # Draft tab open/save lifecycle tests ├── styling/ # Theme and icon tests │ ├── test_theme_manager.py │ └── test_icons.py @@ -133,14 +139,20 @@ tests/ │ ├── test_variable_line_edit.py │ ├── test_variable_popup.py │ └── test_variable_popup_local.py + ├── sidebar/ # Sidebar widget tests + │ ├── test_sidebar.py + │ ├── test_variables_panel.py + │ └── test_snippet_panel.py ├── collections/ # Collection sidebar tests │ ├── test_collection_header.py │ ├── test_collection_tree.py │ ├── test_collection_tree_actions.py │ ├── test_collection_tree_delegate.py - │ └── test_collection_widget.py + │ ├── test_collection_widget.py + │ └── test_new_item_popup.py ├── dialogs/ # Dialog tests │ ├── test_import_dialog.py + │ ├── test_save_request_dialog.py │ └── test_settings_dialog.py ├── environments/ # Environment widget tests │ ├── test_environment_editor.py diff --git a/.github/skills/service-repository-reference/SKILL.md b/.github/skills/service-repository-reference/SKILL.md index ea8d5d7..f829c06 100644 --- a/.github/skills/service-repository-reference/SKILL.md +++ b/.github/skills/service-repository-reference/SKILL.md @@ -35,8 +35,11 @@ cross-layer data interchange. | `get_collection_by_id(collection_id)` | `CollectionModel \| None` | PK lookup | | `get_request_by_id(request_id)` | `RequestModel \| None` | PK lookup | | `get_request_auth_chain(request_id)` | `dict[str, Any] \| None` | Walk parent chain for auth config | +| `get_request_inherited_auth(request_id)` | `dict[str, Any] \| None` | Resolve inherited auth for a request (walks ancestors) | +| `get_collection_inherited_auth(collection_id)` | `dict[str, Any] \| None` | Resolve inherited auth for a collection (walks ancestors) | | `get_request_variable_chain(request_id)` | `dict[str, str]` | Collect variables up the parent chain | | `get_request_variable_chain_detailed(request_id)` | `dict[str, tuple[str, int]]` | Variables with source collection IDs | +| `get_collection_variable_chain_detailed(collection_id)` | `dict[str, tuple[str, int]]` | Variables from collection's parent chain with source IDs | | `get_request_breadcrumb(request_id)` | `list[dict[str, Any]]` | Ancestor path for breadcrumb bar | | `get_collection_breadcrumb(collection_id)` | `list[dict[str, Any]]` | Ancestor path for collection breadcrumb | | `get_saved_responses_for_request(request_id)` | `list[dict[str, Any]]` | Saved responses for a request | @@ -83,6 +86,8 @@ directly to the repository with no added logic. | `update_collection(id, **fields)` | Passthrough (generic field update) | | `update_request(id, **fields)` | Passthrough (generic field update) | | `get_request_auth_chain(request_id)` | Passthrough | +| `get_request_inherited_auth(request_id)` | Passthrough | +| `get_collection_inherited_auth(collection_id)` | Passthrough | | `get_request_variable_chain(request_id)` | Passthrough | | `get_request_breadcrumb(request_id)` | Passthrough | | `get_collection_breadcrumb(collection_id)` | Passthrough | @@ -146,15 +151,23 @@ All methods are `@staticmethod`. ### SnippetGenerator -All methods are `@staticmethod`. +Located in `services/http/snippet_generator/` sub-package (re-exported via +`services/http/__init__.py`). All methods are `@staticmethod`. | Method | Purpose | |--------|---------| -| `curl(method, url, headers, body)` | cURL command | -| `python_requests(method, url, headers, body)` | Python requests library | -| `javascript_fetch(method, url, headers, body)` | JavaScript fetch API | -| `available_languages()` | List of supported language names | -| `generate(language, method, url, headers, body)` | Dispatch to language-specific generator | +| `available_languages()` | List of 23 supported language display names | +| `get_language_info(name)` | Return `LanguageEntry` for a display name, or `None` | +| `generate(language, method, url, headers, body, auth, options)` | Dispatch to language-specific generator | + +**`LanguageEntry`** (`NamedTuple`): `display_name`, `lexer`, `applicable_options`, `generate_fn`. + +**Supported languages (23):** cURL, HTTP, PowerShell (RestMethod), +Shell (HTTPie), Shell (wget), Python (requests), Python (http.client), +JavaScript (fetch), JavaScript (XHR), NodeJS (Axios), NodeJS (Native), +Ruby (Net::HTTP), PHP (cURL), PHP (Guzzle), Dart (http), C (libcurl), +C# (HttpClient), C# (RestSharp), Go (net/http), Java (OkHttp), +Kotlin (OkHttp), Rust (reqwest), Swift (URLSession). ### Shared HTTP utilities (`services/http/header_utils.py`) @@ -162,8 +175,56 @@ All methods are `@staticmethod`. |----------|---------|---------| | `parse_header_dict(raw)` | `dict[str, str]` | Parse `Key: Value\n` lines into a dict | +### Auth handler (`services/http/auth_handler.py`) + +Shared auth header injection used by both `http_worker.py` (actual send) +and `snippet_generator/generator.py` (code snippets). + +| Function | Returns | Purpose | +|----------|---------|---------| +| `apply_auth(auth, url, headers, *, method, body)` | `(url, headers)` | Dispatch to type-specific handler | + +Supports 12 field-based auth types: bearer, basic, apikey, oauth2, digest, +oauth1, hawk, awsv4, jwt, asap, ntlm, edgegrid. HMAC-based JWT (HS256/384/512) +uses stdlib; RSA/EC algorithms require optional `PyJWT`. NTLM is pass-through +(requires live challenge-response). + +### OAuth2Service (`services/http/oauth2_service.py`) + +OAuth 2.0 token exchange for all four grant types. + +| Method | Returns | Purpose | +|--------|---------|---------| +| `get_token(config)` | `OAuth2TokenResult` | Dispatch to grant-type handler | +| `refresh_token(token_url, refresh_token, client_id, client_secret, client_auth)` | `OAuth2TokenResult` | Refresh an expired token | + +Grant types: Authorization Code (browser + redirect), Implicit (browser + +fragment capture), Password Credentials (direct POST), Client Credentials +(direct POST). Browser-based flows open the system browser and start a +local HTTP server to capture the callback. + ## TypedDict schemas +### SnippetOptions (`services/http/snippet_generator/generator.py`) + +```python +class SnippetOptions(TypedDict, total=False): + indent_count: int # default 2 + indent_type: str # "space" or "tab", default "space" + trim_body: bool # default False + follow_redirect: bool # default True + request_timeout: int # seconds, 0 = no timeout + include_boilerplate: bool # default True — imports/main wrappers + async_await: bool # default False — async/await vs promise chains + es6_features: bool # default False — ES6 import/arrow syntax + multiline: bool # default True — split shell commands across lines + long_form: bool # default True — --header vs -H + line_continuation: str # default "\\" — continuation char (\, ^, `) + quote_type: str # default "single" — URL quoting style + follow_original_method: bool # default False — keep method on redirect + silent_mode: bool # default False — suppress progress meter +``` + ### HttpService TypedDicts (`services/http/http_service.py`) ```python @@ -201,6 +262,18 @@ class HttpResponseDict(TypedDict): network: NotRequired[NetworkDict] ``` +### OAuth2TokenResult (`services/http/oauth2_service.py`) + +```python +class OAuth2TokenResult(TypedDict): + access_token: str + token_type: str + expires_in: int + refresh_token: str + scope: str + error: str +``` + ### CollectionService TypedDicts (`services/collection_service.py`) ```python diff --git a/.github/skills/signal-flow/SKILL.md b/.github/skills/signal-flow/SKILL.md index 2eb68ae..f972b8c 100644 --- a/.github/skills/signal-flow/SKILL.md +++ b/.github/skills/signal-flow/SKILL.md @@ -13,9 +13,16 @@ and every connection made in `MainWindow.__init__`. ### Create operations ``` -Header "+" menu - → CollectionHeader.new_collection_requested(None) - → CollectionWidget._create_new_collection(parent_id=None) +Header "New" button → NewItemPopup + → NewItemPopup.new_collection_clicked() + → CollectionHeader.new_collection_requested(None) + → CollectionWidget._create_new_collection(parent_id=None) + + → NewItemPopup.new_request_clicked() + → CollectionHeader.new_request_requested(None) + → CollectionWidget._create_new_request(collection_id=None) + → CollectionWidget.draft_request_requested() + → MainWindow._open_draft_request() Tree context menu → "Add folder" → CollectionTree.new_collection_requested(parent_id) @@ -151,7 +158,17 @@ CollectionTree.currentItemChanged → _on_current_item_changed → selected_collection_changed(collection_id | None) → CollectionWidget → CollectionHeader.set_selected_collection_id - → enables / disables "New request" action in + menu +``` + +### Draft request save flow + +``` +MainWindow._on_save_request() [request_id is None, ctx is not None] + → MainWindow._save_draft_request(ctx, editor) + → SaveRequestDialog.exec() + → on accept: CollectionService.create_request() + → upgrade tab (request_id, dirty=False) + → CollectionWidget tree update ``` ### Tab bar flow @@ -183,7 +200,8 @@ BreadcrumbBar.item_clicked(type, id) BreadcrumbBar.last_segment_renamed(new_name) → MainWindow._on_breadcrumb_rename - → CollectionService.rename_request / rename_collection + → if draft tab (request_id=None): update ctx.draft_name + tab label only + → else: CollectionService.rename_request / rename_collection ``` ### Environment selector flow @@ -240,9 +258,11 @@ MainWindow._toggle_layout_action.triggered ``` MainWindow snippet_act.triggered - → _on_code_snippet - → CodeSnippetDialog(request_data) - → SnippetGenerator.generate(language, ...) + → _on_snippet_shortcut + → _right_sidebar.open_panel("snippet") + → SnippetPanel.update_request(method, url, headers, body, auth) + → SnippetGenerator.generate(language, ..., options=SnippetOptions) + → registry dispatch to language-specific generator (23 languages) ``` ### Settings flow @@ -292,6 +312,22 @@ RequestEditorWidget._on_fetch_schema() → RequestEditorWidget._on_schema_ready(result) ``` +### OAuth 2.0 token flow + +``` +OAuth2Page.get_token_requested (user clicks "Get New Access Token") + → _AuthMixin._on_get_oauth2_token() + → OAuth2TokenWorker.set_config(config) + → QThread.started → OAuth2TokenWorker.run() + → OAuth2Service.get_token(config) + → OAuth2TokenWorker.finished(OAuth2TokenResult) + → _AuthMixin._on_oauth2_token_received(data) + → OAuth2Page.set_token(token, name) + → OAuth2TokenWorker.error(str) + → _AuthMixin._on_oauth2_token_error(msg) + → QMessageBox.warning() +``` + ### Variable popup flow ``` @@ -375,6 +411,9 @@ All other signals in the flow diagrams above are fully wired. | `CollectionWidget` | `item_action_triggered` | `Signal(str, int, str)` | | `CollectionWidget` | `item_name_changed` | `Signal(str, int, str)` | | `CollectionWidget` | `load_finished` | `Signal()` | +| `CollectionWidget` | `draft_request_requested` | `Signal()` | +| `NewItemPopup` | `new_request_clicked` | `Signal()` | +| `NewItemPopup` | `new_collection_clicked` | `Signal()` | ### Request / response subsystem @@ -389,6 +428,10 @@ All other signals in the flow diagrams above are fully wired. | `HttpSendWorker` | `error` | `Signal(str)` | | `SchemaFetchWorker` | `finished` | `Signal(dict)` — `SchemaResultDict` | | `SchemaFetchWorker` | `error` | `Signal(str)` | +| `OAuth2TokenWorker` | `finished` | `Signal(dict)` — `OAuth2TokenResult` | +| `OAuth2TokenWorker` | `error` | `Signal(str)` | +| `OAuth2Page` | `field_changed` | `Signal()` | +| `OAuth2Page` | `get_token_requested` | `Signal()` | ### Tab bar @@ -475,7 +518,7 @@ All connections made in `MainWindow.__init__` (and `_create_menus`): - `forward_action.triggered` → `_navigate_forward` - `import_act.triggered` → `_on_import` - `save_act.triggered` → `_on_save_request` -- `snippet_act.triggered` → `_on_code_snippet` +- `snippet_act.triggered` → `_on_snippet_shortcut` - `settings_act.triggered` → `_on_settings` - `run_act.triggered` → `_on_run_collection` - `exit_act.triggered` → `close` @@ -488,6 +531,7 @@ All connections made in `MainWindow.__init__` (and `_create_menus`): - `editor.send_requested` → `_on_send_request` - `editor.save_requested` → `_on_save_request` - `editor.dirty_changed` → `_sync_save_btn` +- `editor.request_changed` → `_schedule_sidebar_snippet_refresh` - `viewer.save_response_requested` → `_on_save_response` **VariablePopup callbacks (classmethods, wired once):** diff --git a/src/database/models/collections/collection_query_repository.py b/src/database/models/collections/collection_query_repository.py index ed05aa3..b05f63c 100644 --- a/src/database/models/collections/collection_query_repository.py +++ b/src/database/models/collections/collection_query_repository.py @@ -136,21 +136,30 @@ def get_request_by_id(request_id: int) -> RequestModel | None: def get_request_auth_chain(request_id: int) -> dict[str, Any] | None: """Walk the parent collection chain to find the effective auth config. - Returns the request's own auth if set and not ``noauth``. Otherwise - walks up through parent collections and returns the first auth found. - Returns ``None`` if no auth is configured anywhere in the chain. + Inheritance rules: + + - ``auth is None`` means *inherit from parent* — walk up the chain. + - ``{"type": "noauth"}`` means *explicit no-auth* — stop and + return ``None``. + - Any other auth dict means *own auth* — return it immediately. + + Returns ``None`` if nothing is configured or an ancestor opted out. """ with get_session() as session: req = session.get(RequestModel, request_id) if req is None: return None # 1. Check request's own auth - if req.auth and req.auth.get("type") not in (None, "noauth"): + if req.auth is not None: + if req.auth.get("type") == "noauth": + return None return req.auth - # 2. Walk parent collection chain + # 2. Walk parent collection chain (request auth is None → inherit) coll = session.get(CollectionModel, req.collection_id) while coll is not None: - if coll.auth and coll.auth.get("type") not in (None, "noauth"): + if coll.auth is not None: + if coll.auth.get("type") == "noauth": + return None return coll.auth if coll.parent_id is None: break @@ -158,6 +167,51 @@ def get_request_auth_chain(request_id: int) -> dict[str, Any] | None: return None +def get_request_inherited_auth(request_id: int) -> dict[str, Any] | None: + """Return the *parent* auth that a request would inherit. + + Skips the request's own auth and starts walking from its parent + collection. Used by the UI to preview inherited auth on the + "Inherit auth from parent" page. + """ + with get_session() as session: + req = session.get(RequestModel, request_id) + if req is None: + return None + coll = session.get(CollectionModel, req.collection_id) + while coll is not None: + if coll.auth is not None: + if coll.auth.get("type") == "noauth": + return None + return coll.auth + if coll.parent_id is None: + break + coll = session.get(CollectionModel, coll.parent_id) + return None + + +def get_collection_inherited_auth(collection_id: int) -> dict[str, Any] | None: + """Return the *parent* auth that a collection would inherit. + + Starts from the collection's *parent* (not itself) and walks up. + Used by the folder editor to preview inherited auth. + """ + with get_session() as session: + coll = session.get(CollectionModel, collection_id) + if coll is None or coll.parent_id is None: + return None + parent = session.get(CollectionModel, coll.parent_id) + while parent is not None: + if parent.auth is not None: + if parent.auth.get("type") == "noauth": + return None + return parent.auth + if parent.parent_id is None: + break + parent = session.get(CollectionModel, parent.parent_id) + return None + + def get_request_variable_chain(request_id: int) -> dict[str, str]: """Walk the parent collection chain and merge all collection variables. @@ -225,6 +279,36 @@ def get_request_variable_chain_detailed(request_id: int) -> dict[str, tuple[str, return merged +def get_collection_variable_chain_detailed( + collection_id: int, +) -> dict[str, tuple[str, int]]: + """Walk the parent chain from *collection_id* and return variables. + + Returns ``{key: (value, collection_id)}`` — the same shape as + :func:`get_request_variable_chain_detailed` but starting from a + collection instead of a request. + """ + with get_session() as session: + layers: list[tuple[int, list[dict[str, Any]]]] = [] + coll = session.get(CollectionModel, collection_id) + while coll is not None: + if coll.variables: + layers.append((coll.id, coll.variables)) + if coll.parent_id is None: + break + coll = session.get(CollectionModel, coll.parent_id) + merged: dict[str, tuple[str, int]] = {} + for coll_id, var_list in reversed(layers): + for entry in var_list: + if not entry.get("enabled", True): + continue + key = entry.get("key", "") + value = entry.get("value", "") + if key: + merged[key] = (value, coll_id) + return merged + + def get_request_breadcrumb(request_id: int) -> list[dict[str, Any]]: """Return the breadcrumb path from root collection to the request. diff --git a/src/services/collection_service.py b/src/services/collection_service.py index 4abe9ac..abadfbc 100644 --- a/src/services/collection_service.py +++ b/src/services/collection_service.py @@ -14,10 +14,12 @@ fetch_all_collections, get_collection_breadcrumb, get_collection_by_id, + get_collection_inherited_auth, get_recent_requests_for_collection, get_request_auth_chain, get_request_breadcrumb, get_request_by_id, + get_request_inherited_auth, get_request_variable_chain, get_saved_responses_for_request, ) @@ -237,11 +239,33 @@ def update_request(request_id: int, **fields: Any) -> None: def get_request_auth_chain(request_id: int) -> dict[str, Any] | None: """Return the effective auth for a request, walking parent chain. - Checks the request first, then walks up parent collections. - Returns ``None`` if no auth is configured anywhere in the chain. + Respects the inherit / noauth distinction: + + - ``auth is None`` → inherit from parent (walk up the chain). + - ``{"type": "noauth"}`` → explicit no-auth (stop). + - Any other auth dict → use it. """ return get_request_auth_chain(request_id) + @staticmethod + def get_request_inherited_auth(request_id: int) -> dict[str, Any] | None: + """Return only the *parent* auth a request would inherit. + + Skips the request's own auth and walks from its parent + collection upward. Used by the "Inherit auth from parent" UI + to preview the resolved auth. + """ + return get_request_inherited_auth(request_id) + + @staticmethod + def get_collection_inherited_auth(collection_id: int) -> dict[str, Any] | None: + """Return only the *parent* auth a collection would inherit. + + Starts from the collection's parent and walks up. Used by + the folder editor's "Inherit auth from parent" UI. + """ + return get_collection_inherited_auth(collection_id) + @staticmethod def get_request_variable_chain(request_id: int) -> dict[str, str]: """Return the merged collection variables for a request. diff --git a/src/services/http/__init__.py b/src/services/http/__init__.py index fe457d5..99980be 100644 --- a/src/services/http/__init__.py +++ b/src/services/http/__init__.py @@ -5,12 +5,16 @@ from services.http.graphql_schema_service import GraphQLSchemaService from services.http.header_utils import parse_header_dict from services.http.http_service import HttpResponseDict, HttpService -from services.http.snippet_generator import SnippetGenerator +from services.http.oauth2_service import OAuth2Service, OAuth2TokenResult +from services.http.snippet_generator import SnippetGenerator, SnippetOptions __all__ = [ "GraphQLSchemaService", "HttpResponseDict", "HttpService", + "OAuth2Service", + "OAuth2TokenResult", "SnippetGenerator", + "SnippetOptions", "parse_header_dict", ] diff --git a/src/services/http/auth_handler.py b/src/services/http/auth_handler.py new file mode 100644 index 0000000..6f94805 --- /dev/null +++ b/src/services/http/auth_handler.py @@ -0,0 +1,762 @@ +"""Shared auth header injection for HTTP send and snippet generation. + +Computes and injects authentication credentials into request headers +or URL query parameters. Used by :mod:`ui.request.http_worker` +(actual send) and :mod:`services.http.snippet_generator.generator` +(code snippets). + +Every handler receives already-substituted values — variable +replacement is the caller's responsibility. +""" + +from __future__ import annotations + +import base64 +import datetime as dt +import hashlib +import hmac +import json +import logging +import secrets +import time +from collections.abc import Callable +from urllib.parse import parse_qs, quote, urlencode, urlparse + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_Handler = Callable[..., tuple[str, dict[str, str]]] + + +def _entries_map(auth: dict, auth_type: str) -> dict[str, str]: + """Extract Postman key-value entries into a flat ``{key: value}`` map. + + Booleans are normalised to lowercase ``"true"`` / ``"false"`` + strings so that callers can compare consistently. + """ + result: dict[str, str] = {} + for entry in auth.get(auth_type, []): + if isinstance(entry, dict): + val = entry.get("value", "") + if isinstance(val, bool): + result[entry["key"]] = "true" if val else "false" + elif isinstance(val, str): + result[entry["key"]] = val + else: + result[entry["key"]] = str(val) + return result + + +def _b64url(data: bytes) -> str: + """Base64url-encode *data* without padding.""" + return base64.urlsafe_b64encode(data).rstrip(b"=").decode() + + +def _percent_encode(s: str) -> str: + """RFC 5849 percent-encoding (unreserved characters only).""" + return quote(s, safe="") + + +def _hmac_sha256(key: bytes, msg: str) -> bytes: + """Return the raw HMAC-SHA256 digest.""" + return hmac.new(key, msg.encode(), hashlib.sha256).digest() + + +# --------------------------------------------------------------------------- +# Public entry point +# --------------------------------------------------------------------------- + + +def apply_auth( + auth: dict | None, + url: str, + headers: dict[str, str], + *, + method: str = "GET", + body: str | None = None, +) -> tuple[str, dict[str, str]]: + """Inject auth credentials into *headers* and/or *url*. + + Returns the (possibly modified) ``(url, headers)`` pair. + *method* and *body* are required by schemes that include them in + the signature (Digest, OAuth 1.0, Hawk, AWS SigV4, EdgeGrid). + """ + if not auth: + return url, headers + auth_type = auth.get("type", "noauth") + handler = _HANDLERS.get(auth_type) + if handler: + url, headers = handler(auth, url, headers, method=method, body=body) + return url, headers + + +# --------------------------------------------------------------------------- +# Simple token / credential types +# --------------------------------------------------------------------------- + + +def _apply_bearer( + auth: dict, + url: str, + headers: dict[str, str], + *, + method: str = "GET", + body: str | None = None, +) -> tuple[str, dict[str, str]]: + """Bearer token — ``Authorization: Bearer ``.""" + token = _entries_map(auth, "bearer").get("token", "") + if token: + headers["Authorization"] = f"Bearer {token}" + return url, headers + + +def _apply_basic( + auth: dict, + url: str, + headers: dict[str, str], + *, + method: str = "GET", + body: str | None = None, +) -> tuple[str, dict[str, str]]: + """Basic auth — ``Authorization: Basic ``.""" + v = _entries_map(auth, "basic") + username, password = v.get("username", ""), v.get("password", "") + if username or password: + encoded = base64.b64encode(f"{username}:{password}".encode()).decode() + headers["Authorization"] = f"Basic {encoded}" + return url, headers + + +def _apply_apikey( + auth: dict, + url: str, + headers: dict[str, str], + *, + method: str = "GET", + body: str | None = None, +) -> tuple[str, dict[str, str]]: + """API key — custom header or query parameter.""" + v = _entries_map(auth, "apikey") + key, value, add_to = v.get("key", ""), v.get("value", ""), v.get("in", "header") + if key and value: + if add_to == "header": + headers[key] = value + else: + sep = "&" if "?" in url else "?" + url = f"{url}{sep}{_percent_encode(key)}={_percent_encode(value)}" + return url, headers + + +def _apply_oauth2( + auth: dict, + url: str, + headers: dict[str, str], + *, + method: str = "GET", + body: str | None = None, +) -> tuple[str, dict[str, str]]: + """OAuth 2.0 manual token — ``Authorization: ``.""" + v = _entries_map(auth, "oauth2") + token = v.get("accessToken", "") + prefix = v.get("headerPrefix", "Bearer") + add_to = v.get("addTokenTo", "header") + if token: + if add_to == "header": + headers["Authorization"] = f"{prefix} {token}" if prefix else token + else: + sep = "&" if "?" in url else "?" + url = f"{url}{sep}access_token={_percent_encode(token)}" + return url, headers + + +# --------------------------------------------------------------------------- +# Digest Auth (RFC 7616) +# --------------------------------------------------------------------------- + +_DIGEST_HASH: dict[str, str] = { + "MD5": "md5", + "MD5-sess": "md5", + "SHA-256": "sha256", + "SHA-256-sess": "sha256", + "SHA-512-256": "sha512_256", + "SHA-512-256-sess": "sha512_256", +} + + +def _apply_digest( + auth: dict, + url: str, + headers: dict[str, str], + *, + method: str = "GET", + body: str | None = None, +) -> tuple[str, dict[str, str]]: + """Digest auth — ``Authorization: Digest ...``.""" + v = _entries_map(auth, "digest") + username, password = v.get("username", ""), v.get("password", "") + realm, nonce = v.get("realm", ""), v.get("nonce", "") + algorithm = v.get("algorithm", "MD5") + qop = v.get("qop", "") + nc = v.get("nonceCount", "00000001") + cnonce = v.get("clientNonce", "") or secrets.token_hex(8) + opaque = v.get("opaque", "") + + uri = urlparse(url).path or "/" + hash_name = _DIGEST_HASH.get(algorithm, "md5") + + def _h(data: str) -> str: + return hashlib.new(hash_name, data.encode()).hexdigest() + + a1 = f"{username}:{realm}:{password}" + if algorithm.endswith("-sess"): + a1 = f"{_h(a1)}:{nonce}:{cnonce}" + + a2 = f"{method}:{uri}" + if qop == "auth-int" and body: + a2 = f"{method}:{uri}:{_h(body)}" + + ha1, ha2 = _h(a1), _h(a2) + if qop in ("auth", "auth-int"): + response = _h(f"{ha1}:{nonce}:{nc}:{cnonce}:{qop}:{ha2}") + else: + response = _h(f"{ha1}:{nonce}:{ha2}") + + parts = [ + f'username="{username}"', + f'realm="{realm}"', + f'nonce="{nonce}"', + f'uri="{uri}"', + f"algorithm={algorithm}", + f'response="{response}"', + ] + if qop: + parts.extend([f"qop={qop}", f"nc={nc}", f'cnonce="{cnonce}"']) + if opaque: + parts.append(f'opaque="{opaque}"') + headers["Authorization"] = f"Digest {', '.join(parts)}" + return url, headers + + +# --------------------------------------------------------------------------- +# OAuth 1.0 (RFC 5849) +# --------------------------------------------------------------------------- + + +def _apply_oauth1( + auth: dict, + url: str, + headers: dict[str, str], + *, + method: str = "GET", + body: str | None = None, +) -> tuple[str, dict[str, str]]: + """OAuth 1.0 — ``Authorization: OAuth ...`` or query/body params.""" + v = _entries_map(auth, "oauth1") + consumer_key = v.get("consumerKey", "") + consumer_secret = v.get("consumerSecret", "") + token = v.get("token", "") + token_secret = v.get("tokenSecret", "") + sig_method = v.get("signatureMethod", "HMAC-SHA1") + timestamp = v.get("timestamp", "") or str(int(time.time())) + nonce = v.get("nonce", "") or secrets.token_hex(16) + version = v.get("version", "1.0") + realm = v.get("realm", "") + callback_url = v.get("callbackUrl", "") + verifier = v.get("verifier", "") + include_body_hash = v.get("includeBodyHash", "false") == "true" + add_empty = v.get("addEmptyParamsToSign", "false") == "true" + to_header_raw = v.get("addParamsToHeader", "true") + + # 1. Collect OAuth params + oauth: dict[str, str] = { + "oauth_consumer_key": consumer_key, + "oauth_signature_method": sig_method, + "oauth_timestamp": timestamp, + "oauth_nonce": nonce, + "oauth_version": version, + } + if token: + oauth["oauth_token"] = token + if callback_url: + oauth["oauth_callback"] = callback_url + if verifier: + oauth["oauth_verifier"] = verifier + + # Body hash (RFC 5849 §3.4.1.3.1) + if include_body_hash and body: + if sig_method == "HMAC-SHA256": + bh = base64.b64encode(hashlib.sha256(body.encode()).digest()).decode() + else: + bh = base64.b64encode(hashlib.sha1(body.encode()).digest()).decode() + oauth["oauth_body_hash"] = bh + + # Add empty params if requested + if add_empty: + for k, val in list(oauth.items()): + if not val: + oauth[k] = "" + + # 2. Merge query params for base-string + parsed = urlparse(url) + base_url = f"{parsed.scheme}://{parsed.netloc}{parsed.path}" + all_params: list[tuple[str, str]] = [] + for k, vals in parse_qs(parsed.query, keep_blank_values=True).items(): + for val in vals: + all_params.append((k, val)) + all_params.extend(oauth.items()) + + # 3. Base string + param_str = "&".join( + f"{_percent_encode(k)}={_percent_encode(val)}" for k, val in sorted(all_params) + ) + base_string = f"{method.upper()}&{_percent_encode(base_url)}&{_percent_encode(param_str)}" + + # 4. Signature + signing_key = f"{_percent_encode(consumer_secret)}&{_percent_encode(token_secret)}" + if sig_method == "HMAC-SHA1": + raw_sig = hmac.new( + signing_key.encode(), + base_string.encode(), + hashlib.sha1, + ).digest() + sig = base64.b64encode(raw_sig).decode() + elif sig_method == "HMAC-SHA256": + raw_sig = hmac.new( + signing_key.encode(), + base_string.encode(), + hashlib.sha256, + ).digest() + sig = base64.b64encode(raw_sig).decode() + elif sig_method == "PLAINTEXT": + sig = signing_key + else: + sig = "" + logger.warning("OAuth 1.0 %s requires external libraries", sig_method) + oauth["oauth_signature"] = sig + + # 5. Emit — "true" = headers, "false" = URL query, "body" = request body + if to_header_raw == "true": + parts = [ + f'{_percent_encode(k)}="{_percent_encode(val)}"' for k, val in sorted(oauth.items()) + ] + if realm: + parts.insert(0, f'realm="{realm}"') + headers["Authorization"] = f"OAuth {', '.join(parts)}" + else: + qs = urlencode(oauth) + sep = "&" if "?" in url else "?" + url = f"{url}{sep}{qs}" + return url, headers + + +# --------------------------------------------------------------------------- +# Hawk Authentication +# --------------------------------------------------------------------------- + + +def _apply_hawk( + auth: dict, + url: str, + headers: dict[str, str], + *, + method: str = "GET", + body: str | None = None, +) -> tuple[str, dict[str, str]]: + """Hawk auth — ``Authorization: Hawk id=... mac=...``.""" + v = _entries_map(auth, "hawk") + auth_id = v.get("authId", "") + auth_key = v.get("authKey", "") + algorithm = v.get("algorithm", "sha256") + nonce = v.get("nonce", "") or secrets.token_hex(8) + ext = v.get("extraData", "") + app_id = v.get("appId", "") + delegation = v.get("delegation", "") + ts = v.get("timestamp", "") or str(int(time.time())) + include_hash = v.get("includePayloadHash", "false") == "true" + + parsed = urlparse(url) + host = parsed.hostname or "" + port = str(parsed.port or (443 if parsed.scheme == "https" else 80)) + resource = parsed.path + (f"?{parsed.query}" if parsed.query else "") + + # Payload hash (gated by includePayloadHash checkbox) + payload_hash = "" + if include_hash and body: + ctype = headers.get("Content-Type", "").split(";")[0].strip() + hash_input = f"hawk.1.payload\n{ctype}\n{body}\n" + raw = hashlib.new(algorithm, hash_input.encode()).digest() + payload_hash = base64.b64encode(raw).decode() + + # Normalised string + normalized = ( + f"hawk.1.header\n{ts}\n{nonce}\n{method.upper()}\n" + f"{resource}\n{host}\n{port}\n{payload_hash}\n{ext}\n" + ) + if app_id: + normalized += f"{app_id}\n{delegation}\n" + + mac = base64.b64encode( + hmac.new(auth_key.encode(), normalized.encode(), algorithm).digest(), + ).decode() + + parts = [f'id="{auth_id}"', f'ts="{ts}"', f'nonce="{nonce}"', f'mac="{mac}"'] + if payload_hash: + parts.append(f'hash="{payload_hash}"') + if ext: + parts.append(f'ext="{ext}"') + if app_id: + parts.append(f'app="{app_id}"') + if delegation: + parts.append(f'dlg="{delegation}"') + headers["Authorization"] = f"Hawk {', '.join(parts)}" + return url, headers + + +# --------------------------------------------------------------------------- +# AWS Signature V4 +# --------------------------------------------------------------------------- + + +def _apply_awsv4( + auth: dict, + url: str, + headers: dict[str, str], + *, + method: str = "GET", + body: str | None = None, +) -> tuple[str, dict[str, str]]: + """AWS SigV4 — ``Authorization: AWS4-HMAC-SHA256 ...``.""" + v = _entries_map(auth, "awsv4") + access_key = v.get("accessKey", "") + secret_key = v.get("secretKey", "") + region = v.get("region", "us-east-1") + service = v.get("service", "") + session_token = v.get("sessionToken", "") + + now = dt.datetime.now(dt.UTC) + amz_date = now.strftime("%Y%m%dT%H%M%SZ") + date_stamp = now.strftime("%Y%m%d") + + parsed = urlparse(url) + host = parsed.netloc + uri = parsed.path or "/" + payload_hash = hashlib.sha256((body or "").encode()).hexdigest() + + # Required headers + headers["x-amz-date"] = amz_date + headers["x-amz-content-sha256"] = payload_hash + if session_token: + headers["x-amz-security-token"] = session_token + if not any(k.lower() == "host" for k in headers): + headers["Host"] = host + + # Canonical query string + qs_items = sorted(parse_qs(parsed.query, keep_blank_values=True).items()) + canonical_qs = "&".join(f"{quote(k, safe='')}={quote(vs[0], safe='')}" for k, vs in qs_items) + + # Signed headers + signed = sorted((k.lower(), v.strip()) for k, v in headers.items()) + canonical_hdrs = "".join(f"{k}:{val}\n" for k, val in signed) + signed_names = ";".join(k for k, _ in signed) + + canonical_request = ( + f"{method}\n{uri}\n{canonical_qs}\n{canonical_hdrs}\n{signed_names}\n{payload_hash}" + ) + + # String to sign + scope = f"{date_stamp}/{region}/{service}/aws4_request" + string_to_sign = ( + f"AWS4-HMAC-SHA256\n{amz_date}\n{scope}\n" + f"{hashlib.sha256(canonical_request.encode()).hexdigest()}" + ) + + # Signing key chain + k_date = _hmac_sha256(f"AWS4{secret_key}".encode(), date_stamp) + k_region = _hmac_sha256(k_date, region) + k_service = _hmac_sha256(k_region, service) + k_signing = _hmac_sha256(k_service, "aws4_request") + signature = hmac.new( + k_signing, + string_to_sign.encode(), + hashlib.sha256, + ).hexdigest() + + headers["Authorization"] = ( + f"AWS4-HMAC-SHA256 Credential={access_key}/{scope}, " + f"SignedHeaders={signed_names}, Signature={signature}" + ) + return url, headers + + +# --------------------------------------------------------------------------- +# JWT Bearer +# --------------------------------------------------------------------------- + +_JWT_HMAC = { + "HS256": hashlib.sha256, + "HS384": hashlib.sha384, + "HS512": hashlib.sha512, +} + + +def _build_jwt( + payload_json: str, + key: str, + algorithm: str, + headers_json: str, + is_key_b64: bool, +) -> str | None: + """Build a JWT token. + + HMAC algorithms (HS256 / HS384 / HS512) use stdlib only. + RSA / EC / PS algorithms attempt ``PyJWT``; returns ``None`` + when the library is not installed. + """ + hash_fn = _JWT_HMAC.get(algorithm) + + if hash_fn is None: + # RSA / EC / PS — try PyJWT + try: + import jwt as pyjwt # type: ignore[import-not-found] + except ImportError: + logger.warning( + "%s requires PyJWT; install with: pip install PyJWT cryptography", + algorithm, + ) + return None + try: + payload = json.loads(payload_json) if payload_json.strip() else {} + extra = json.loads(headers_json) if headers_json.strip() else {} + result: str = pyjwt.encode(payload, key, algorithm=algorithm, headers=extra) + return result + except Exception: + logger.exception("JWT signing failed (%s)", algorithm) + return None + + # HMAC path — stdlib + try: + payload = json.loads(payload_json) if payload_json.strip() else {} + except json.JSONDecodeError: + payload = {} + try: + extra = json.loads(headers_json) if headers_json.strip() else {} + except json.JSONDecodeError: + extra = {} + + header = {"alg": algorithm, "typ": "JWT", **extra} + hdr_b64 = _b64url(json.dumps(header, separators=(",", ":")).encode()) + pay_b64 = _b64url(json.dumps(payload, separators=(",", ":")).encode()) + signing_input = f"{hdr_b64}.{pay_b64}" + key_bytes = base64.b64decode(key) if is_key_b64 else key.encode() + sig = hmac.new(key_bytes, signing_input.encode(), hash_fn).digest() + return f"{signing_input}.{_b64url(sig)}" + + +def _apply_jwt( + auth: dict, + url: str, + headers: dict[str, str], + *, + method: str = "GET", + body: str | None = None, +) -> tuple[str, dict[str, str]]: + """JWT Bearer — ``Authorization: ``.""" + v = _entries_map(auth, "jwt") + algorithm = v.get("algorithm", "HS256") + secret = v.get("secret", "") + private_key = v.get("privateKey", "") + payload_json = v.get("payload", "{}") + headers_json = v.get("headers", "{}") + is_b64 = v.get("isSecretBase64Encoded", "false") == "true" + add_to = v.get("addTokenTo", "header") + prefix = v.get("headerPrefix", "Bearer") + query_key = v.get("queryParamKey", "token") + + key = private_key if algorithm.startswith(("RS", "ES", "PS")) else secret + token = _build_jwt(payload_json, key, algorithm, headers_json, is_b64) + if not token: + return url, headers + if add_to == "header": + headers["Authorization"] = f"{prefix} {token}" if prefix else token + else: + sep = "&" if "?" in url else "?" + url = f"{url}{sep}{_percent_encode(query_key)}={_percent_encode(token)}" + return url, headers + + +# --------------------------------------------------------------------------- +# ASAP (Atlassian) +# --------------------------------------------------------------------------- + + +def _apply_asap( + auth: dict, + url: str, + headers: dict[str, str], + *, + method: str = "GET", + body: str | None = None, +) -> tuple[str, dict[str, str]]: + """ASAP — ``Authorization: Bearer `` with Atlassian claims.""" + v = _entries_map(auth, "asap") + issuer = v.get("issuer", "") + subject = v.get("subject", "") + audience = v.get("audience", "") + private_key = v.get("privateKey", "") + kid = v.get("kid", "") + algorithm = v.get("algorithm", "RS256") + expires_in = int(v.get("expiresIn", "3600") or "3600") + claims_json = v.get("claims", "{}") + + now_ts = int(time.time()) + try: + extra_claims = json.loads(claims_json) if claims_json.strip() else {} + except json.JSONDecodeError: + extra_claims = {} + + payload: dict[str, object] = { + "iss": issuer, + "iat": now_ts, + "exp": now_ts + expires_in, + "jti": secrets.token_hex(16), + **extra_claims, + } + if subject: + payload["sub"] = subject + if audience: + payload["aud"] = audience + + extra_hdrs = json.dumps({"kid": kid}) if kid else "{}" + token = _build_jwt( + json.dumps(payload), + private_key, + algorithm, + extra_hdrs, + False, + ) + if token: + headers["Authorization"] = f"Bearer {token}" + return url, headers + + +# --------------------------------------------------------------------------- +# NTLM Authentication +# --------------------------------------------------------------------------- + + +def _apply_ntlm( + auth: dict, + url: str, + headers: dict[str, str], + *, + method: str = "GET", + body: str | None = None, +) -> tuple[str, dict[str, str]]: + """NTLM — stored for display; live negotiation is not pre-computable.""" + return url, headers + + +# --------------------------------------------------------------------------- +# Akamai EdgeGrid +# --------------------------------------------------------------------------- + + +def _apply_edgegrid( + auth: dict, + url: str, + headers: dict[str, str], + *, + method: str = "GET", + body: str | None = None, +) -> tuple[str, dict[str, str]]: + """Akamai EdgeGrid — ``Authorization: EG1-HMAC-SHA256 ...``.""" + v = _entries_map(auth, "edgegrid") + access_token = v.get("accessToken", "") + client_token = v.get("clientToken", "") + client_secret = v.get("clientSecret", "") + nonce = v.get("nonce", "") or secrets.token_urlsafe(16) + timestamp = v.get("timestamp", "") + headers_to_sign = v.get("headersToSign", "") + max_body = int(v.get("maxBody", "131072") or "131072") + + if not timestamp: + timestamp = dt.datetime.now(dt.UTC).strftime( + "%Y%m%dT%H:%M:%S+0000", + ) + + parsed = urlparse(url) + path_query = parsed.path + (f"?{parsed.query}" if parsed.query else "") + + # Content hash (POST / PUT only) + content_hash = "" + if body and method.upper() in ("POST", "PUT"): + trimmed = body[:max_body].encode() + content_hash = base64.b64encode( + hashlib.sha256(trimmed).digest(), + ).decode() + + # Canonical signed headers + names = [h.strip().lower() for h in headers_to_sign.split(",") if h.strip()] + canon_hdrs = "" + for name in sorted(names): + val = headers.get(name, "") + canon_hdrs += f"{name}:{' '.join(val.split())}\t" + + # Auth stub (unsigned) + auth_stub = ( + f"EG1-HMAC-SHA256 client_token={client_token};" + f"access_token={access_token};" + f"timestamp={timestamp};nonce={nonce};" + ) + + # Signing key + try: + secret_bytes = base64.b64decode(client_secret) + except Exception: + secret_bytes = client_secret.encode() + signing_key = hmac.new( + secret_bytes, + timestamp.encode(), + hashlib.sha256, + ).digest() + + data_to_sign = "\t".join( + [ + method.upper(), + parsed.scheme, + parsed.hostname or "", + path_query, + canon_hdrs, + content_hash, + auth_stub, + ] + ) + signature = base64.b64encode( + hmac.new(signing_key, data_to_sign.encode(), hashlib.sha256).digest(), + ).decode() + + headers["Authorization"] = f"{auth_stub}signature={signature}" + return url, headers + + +# --------------------------------------------------------------------------- +# Dispatch table +# --------------------------------------------------------------------------- + +_HANDLERS: dict[str, _Handler] = { + "bearer": _apply_bearer, + "basic": _apply_basic, + "apikey": _apply_apikey, + "oauth2": _apply_oauth2, + "digest": _apply_digest, + "oauth1": _apply_oauth1, + "hawk": _apply_hawk, + "awsv4": _apply_awsv4, + "jwt": _apply_jwt, + "asap": _apply_asap, + "ntlm": _apply_ntlm, + "edgegrid": _apply_edgegrid, +} diff --git a/src/services/http/oauth2_service.py b/src/services/http/oauth2_service.py new file mode 100644 index 0000000..03a8452 --- /dev/null +++ b/src/services/http/oauth2_service.py @@ -0,0 +1,419 @@ +"""OAuth 2.0 token exchange service. + +Performs the actual HTTP token exchanges for all four grant types: + +- **Authorization Code** — opens browser, starts local redirect server, + exchanges code for token. +- **Implicit** — opens browser, captures token from redirect fragment. +- **Password Credentials** — direct POST to token endpoint. +- **Client Credentials** — direct POST to token endpoint. + +All methods are ``@staticmethod`` following the project convention. +""" + +from __future__ import annotations + +import contextlib +import html +import http.server +import logging +import secrets +import webbrowser +from typing import TypedDict +from urllib.parse import parse_qs, urlencode, urlparse + +import httpx + +logger = logging.getLogger(__name__) + +_TIMEOUT = 30 +_REDIRECT_PORT_RANGE = (5000, 5010) + + +class OAuth2TokenResult(TypedDict): + """Result of an OAuth 2.0 token exchange.""" + + access_token: str + token_type: str + expires_in: int + refresh_token: str + scope: str + error: str + + +class OAuth2Service: + """Static methods for OAuth 2.0 token exchange flows.""" + + @staticmethod + def get_token(config: dict) -> OAuth2TokenResult: + """Dispatch to the correct grant-type handler. + + *config* is the dict returned by ``OAuth2Page.get_config()``. + """ + grant = config.get("grant_type", "authorization_code") + if grant == "authorization_code": + return OAuth2Service._authorization_code(config) + if grant == "implicit": + return OAuth2Service._implicit(config) + if grant == "password": + return OAuth2Service._password_credentials(config) + if grant == "client_credentials": + return OAuth2Service._client_credentials(config) + return _error_result(f"Unknown grant type: {grant}") + + # ------------------------------------------------------------------ + # Grant type implementations + # ------------------------------------------------------------------ + + @staticmethod + def _authorization_code(config: dict) -> OAuth2TokenResult: + """Authorization Code grant — browser + redirect + code exchange.""" + auth_url = config.get("authUrl", "") + token_url = config.get("accessTokenUrl", "") + client_id = config.get("clientId", "") + client_secret = str(config.get("clientSecret", "")) + scope = config.get("scope", "") + state = config.get("state", "") or secrets.token_urlsafe(16) + callback_url = config.get("callbackUrl", "") + client_auth = config.get("client_authentication", "header") + + if not auth_url or not token_url or not client_id: + return _error_result("Auth URL, Token URL, and Client ID are required") + + # 1. Determine redirect URI and start local server + redirect_uri, port = _parse_redirect(callback_url) + code_holder: dict[str, str] = {} + server = _start_redirect_server(port, code_holder) + if server is None: + return _error_result(f"Could not start redirect server on port {port}") + + try: + # 2. Open browser to authorization endpoint + params = { + "response_type": "code", + "client_id": client_id, + "redirect_uri": redirect_uri, + "scope": scope, + "state": state, + } + browser_url = f"{auth_url}?{urlencode({k: v for k, v in params.items() if v})}" + webbrowser.open(browser_url) + + # 3. Wait for redirect (blocking — runs on worker thread) + server.handle_request() + server.server_close() + + code = code_holder.get("code", "") + returned_state = code_holder.get("state", "") + error = code_holder.get("error", "") + + if error: + return _error_result(f"Authorization error: {error}") + if state and returned_state and returned_state != state: + return _error_result("State mismatch — possible CSRF attack") + if not code: + return _error_result("No authorization code received") + + # 4. Exchange code for token + return _exchange_code( + token_url=token_url, + code=code, + redirect_uri=redirect_uri, + client_id=client_id, + client_secret=client_secret, + client_auth=client_auth, + ) + finally: + with contextlib.suppress(Exception): + server.server_close() + + @staticmethod + def _implicit(config: dict) -> OAuth2TokenResult: + """Implicit grant — browser redirect with token in fragment.""" + auth_url = config.get("authUrl", "") + client_id = config.get("clientId", "") + scope = config.get("scope", "") + state = config.get("state", "") or secrets.token_urlsafe(16) + callback_url = config.get("callbackUrl", "") + + if not auth_url or not client_id: + return _error_result("Auth URL and Client ID are required") + + redirect_uri, port = _parse_redirect(callback_url) + token_holder: dict[str, str] = {} + server = _start_fragment_server(port, token_holder) + if server is None: + return _error_result(f"Could not start redirect server on port {port}") + + try: + params = { + "response_type": "token", + "client_id": client_id, + "redirect_uri": redirect_uri, + "scope": scope, + "state": state, + } + browser_url = f"{auth_url}?{urlencode({k: v for k, v in params.items() if v})}" + webbrowser.open(browser_url) + + # Handle two requests: first the redirect, then the JS POST + server.handle_request() + server.handle_request() + server.server_close() + + error = token_holder.get("error", "") + if error: + return _error_result(f"Authorization error: {error}") + + access_token = token_holder.get("access_token", "") + if not access_token: + return _error_result("No access token received") + + return OAuth2TokenResult( + access_token=access_token, + token_type=token_holder.get("token_type", "Bearer"), + expires_in=int(token_holder.get("expires_in", "0") or "0"), + refresh_token="", + scope=token_holder.get("scope", ""), + error="", + ) + finally: + with contextlib.suppress(Exception): + server.server_close() + + @staticmethod + def _password_credentials(config: dict) -> OAuth2TokenResult: + """Resource Owner Password Credentials grant — direct token request.""" + token_url = config.get("accessTokenUrl", "") + client_id = config.get("clientId", "") + client_secret = str(config.get("clientSecret", "")) + username = config.get("username", "") + password = str(config.get("password", "")) + scope = config.get("scope", "") + client_auth = config.get("client_authentication", "header") + + if not token_url: + return _error_result("Access Token URL is required") + + data = {"grant_type": "password", "username": username, "password": password} + if scope: + data["scope"] = scope + + return _post_token_request(token_url, data, client_id, client_secret, client_auth) + + @staticmethod + def _client_credentials(config: dict) -> OAuth2TokenResult: + """Client Credentials grant — direct token request.""" + token_url = config.get("accessTokenUrl", "") + client_id = config.get("clientId", "") + client_secret = str(config.get("clientSecret", "")) + scope = config.get("scope", "") + client_auth = config.get("client_authentication", "header") + + if not token_url: + return _error_result("Access Token URL is required") + + data: dict[str, str] = {"grant_type": "client_credentials"} + if scope: + data["scope"] = scope + + return _post_token_request(token_url, data, client_id, client_secret, client_auth) + + @staticmethod + def refresh_token( + token_url: str, + refresh_token: str, + client_id: str, + client_secret: str, + client_auth: str = "header", + ) -> OAuth2TokenResult: + """Refresh an expired access token.""" + data = {"grant_type": "refresh_token", "refresh_token": refresh_token} + return _post_token_request(token_url, data, client_id, client_secret, client_auth) + + +# ------------------------------------------------------------------ +# Internal helpers +# ------------------------------------------------------------------ + + +def _error_result(msg: str) -> OAuth2TokenResult: + """Create an error result.""" + return OAuth2TokenResult( + access_token="", + token_type="", + expires_in=0, + refresh_token="", + scope="", + error=msg, + ) + + +def _post_token_request( + token_url: str, + data: dict[str, str], + client_id: str, + client_secret: str, + client_auth: str, +) -> OAuth2TokenResult: + """POST to a token endpoint and parse the response.""" + headers: dict[str, str] = {"Accept": "application/json"} + use_basic_auth = client_auth == "header" and bool(client_id) + + if not use_basic_auth: + data["client_id"] = client_id + if client_secret: + data["client_secret"] = client_secret + + try: + with httpx.Client(timeout=_TIMEOUT, follow_redirects=True) as client: + if use_basic_auth: + resp = client.post( + token_url, + data=data, + headers=headers, + auth=httpx.BasicAuth(client_id, client_secret), + ) + else: + resp = client.post(token_url, data=data, headers=headers) + resp.raise_for_status() + body = resp.json() + except httpx.HTTPStatusError as exc: + try: + err_body = exc.response.json() + msg = err_body.get("error_description", err_body.get("error", str(exc))) + except Exception: + msg = str(exc) + return _error_result(msg) + except Exception as exc: + return _error_result(str(exc)) + + return OAuth2TokenResult( + access_token=body.get("access_token", ""), + token_type=body.get("token_type", "Bearer"), + expires_in=int(body.get("expires_in", 0)), + refresh_token=body.get("refresh_token", ""), + scope=body.get("scope", ""), + error=body.get("error", ""), + ) + + +def _exchange_code( + token_url: str, + code: str, + redirect_uri: str, + client_id: str, + client_secret: str, + client_auth: str, +) -> OAuth2TokenResult: + """Exchange an authorization code for a token.""" + data = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + } + return _post_token_request(token_url, data, client_id, client_secret, client_auth) + + +def _parse_redirect(callback_url: str) -> tuple[str, int]: + """Extract the redirect URI and port from the callback URL.""" + if not callback_url: + callback_url = "http://localhost:5000/callback" + parsed = urlparse(callback_url) + port = parsed.port or 5000 + redirect_uri = callback_url + return redirect_uri, port + + +def _start_redirect_server( + port: int, + result: dict[str, str], +) -> http.server.HTTPServer | None: + """Start a one-shot HTTP server to capture the authorization code.""" + + class _Handler(http.server.BaseHTTPRequestHandler): + def do_GET(self) -> None: + """Capture code from query parameters.""" + qs = parse_qs(urlparse(self.path).query) + result["code"] = qs.get("code", [""])[0] + result["state"] = qs.get("state", [""])[0] + result["error"] = qs.get("error", [""])[0] + self.send_response(200) + self.send_header("Content-Type", "text/html") + self.end_headers() + self.wfile.write( + b"

Authorization complete

" + b"

You can close this window.

" + ) + + def log_message(self, format: str, *args: object) -> None: + """Suppress default request logging.""" + + try: + server = http.server.HTTPServer(("127.0.0.1", port), _Handler) + server.timeout = 120 + return server + except OSError: + logger.warning("Could not bind to port %d", port) + return None + + +_FRAGMENT_CAPTURE_HTML = """\ + +

Capturing token…

+ +""" + + +def _start_fragment_server( + port: int, + result: dict[str, str], +) -> http.server.HTTPServer | None: + """Start a server that captures token from URL fragment via JS POST.""" + + class _Handler(http.server.BaseHTTPRequestHandler): + def do_GET(self) -> None: + """Serve the fragment-capture HTML page.""" + self.send_response(200) + self.send_header("Content-Type", "text/html") + self.end_headers() + self.wfile.write(_FRAGMENT_CAPTURE_HTML.encode()) + + def do_POST(self) -> None: + """Receive the token fragment forwarded by JS.""" + length = int(self.headers.get("Content-Length", 0)) + body = self.rfile.read(length).decode() + params = parse_qs(body) + for key in ("access_token", "token_type", "expires_in", "scope", "error"): + vals = params.get(key, []) + if vals: + result[key] = html.escape(vals[0]) + self.send_response(200) + self.end_headers() + + def log_message(self, format: str, *args: object) -> None: + """Suppress default request logging.""" + + try: + server = http.server.HTTPServer(("127.0.0.1", port), _Handler) + server.timeout = 120 + return server + except OSError: + logger.warning("Could not bind to port %d", port) + return None diff --git a/src/services/http/snippet_generator.py b/src/services/http/snippet_generator.py deleted file mode 100644 index 346abf5..0000000 --- a/src/services/http/snippet_generator.py +++ /dev/null @@ -1,133 +0,0 @@ -"""Generate code snippets for HTTP requests in various languages. - -Provides a ``SnippetGenerator`` with static methods that convert -request parameters (method, URL, headers, body) into runnable code -snippets for different languages and libraries. -""" - -from __future__ import annotations - -import json -import shlex - -from services.http.header_utils import parse_header_dict - - -class SnippetGenerator: - """Generate code snippets from request parameters. - - Every method is a ``@staticmethod`` — no shared state. - """ - - @staticmethod - def curl( - *, - method: str, - url: str, - headers: str | None = None, - body: str | None = None, - ) -> str: - """Generate a cURL command.""" - parts = ["curl", "-X", method.upper(), shlex.quote(url)] - for key, value in parse_header_dict(headers).items(): - parts.append("-H") - parts.append(shlex.quote(f"{key}: {value}")) - if body: - parts.append("-d") - parts.append(shlex.quote(body)) - return " \\\n ".join(parts) - - @staticmethod - def python_requests( - *, - method: str, - url: str, - headers: str | None = None, - body: str | None = None, - ) -> str: - """Generate a Python ``requests`` snippet.""" - lines = ["import requests", ""] - hdr = parse_header_dict(headers) - - if hdr: - lines.append(f"headers = {json.dumps(hdr, indent=4)}") - lines.append("") - - call = f'response = requests.{method.lower()}("{url}"' - if hdr: - call += ", headers=headers" - if body: - # Try JSON - try: - json.loads(body) - lines.append(f"payload = {body}") - lines.append("") - call += ", json=payload" - except (json.JSONDecodeError, TypeError): - call += f', data="""{body}"""' - call += ")" - lines.append(call) - lines.append("print(response.status_code)") - lines.append("print(response.text)") - return "\n".join(lines) - - @staticmethod - def javascript_fetch( - *, - method: str, - url: str, - headers: str | None = None, - body: str | None = None, - ) -> str: - """Generate a JavaScript ``fetch`` snippet.""" - hdr = parse_header_dict(headers) - opts: list[str] = [f' method: "{method.upper()}"'] - if hdr: - hdr_str = json.dumps(hdr, indent=4) - # Indent the headers block - hdr_lines = hdr_str.splitlines() - indented = "\n".join( - f" {line}" if i > 0 else f" headers: {line}" for i, line in enumerate(hdr_lines) - ) - opts.append(indented) - if body: - opts.append(f" body: {json.dumps(body)}") - opts_block = ",\n".join(opts) - return ( - f'fetch("{url}", {{\n' - f"{opts_block}\n" - f"}})\n" - f" .then(response => response.json())\n" - f" .then(data => console.log(data));" - ) - - @staticmethod - def available_languages() -> list[str]: - """Return the list of supported snippet languages.""" - return ["cURL", "Python (requests)", "JavaScript (fetch)"] - - @staticmethod - def generate( - language: str, - *, - method: str, - url: str, - headers: str | None = None, - body: str | None = None, - ) -> str: - """Generate a snippet for the given language label. - - The *language* parameter should be one of the values returned by - :meth:`available_languages`. - """ - if language == "cURL": - return SnippetGenerator.curl(method=method, url=url, headers=headers, body=body) - if language == "Python (requests)": - return SnippetGenerator.python_requests( - method=method, url=url, headers=headers, body=body - ) - if language == "JavaScript (fetch)": - return SnippetGenerator.javascript_fetch( - method=method, url=url, headers=headers, body=body - ) - return f"# Unsupported language: {language}" diff --git a/src/services/http/snippet_generator/__init__.py b/src/services/http/snippet_generator/__init__.py new file mode 100644 index 0000000..3798ef8 --- /dev/null +++ b/src/services/http/snippet_generator/__init__.py @@ -0,0 +1,21 @@ +"""Code snippet generation sub-package. + +Re-exports the public API so existing imports continue to work:: + + from services.http.snippet_generator import SnippetGenerator + from services.http.snippet_generator import SnippetOptions +""" + +from __future__ import annotations + +from services.http.snippet_generator.generator import ( + LanguageEntry, + SnippetGenerator, + SnippetOptions, +) + +__all__ = [ + "LanguageEntry", + "SnippetGenerator", + "SnippetOptions", +] diff --git a/src/services/http/snippet_generator/compiled_snippets.py b/src/services/http/snippet_generator/compiled_snippets.py new file mode 100644 index 0000000..a281a0f --- /dev/null +++ b/src/services/http/snippet_generator/compiled_snippets.py @@ -0,0 +1,573 @@ +"""Snippet generators for compiled / statically-typed languages. + +Generators for Go, Rust, C (libcurl), Swift, Java, Kotlin, and C#. +""" + +from __future__ import annotations + +import json + +from services.http.snippet_generator.generator import LanguageEntry, SnippetOptions, indent_str + +# --------------------------------------------------------------------------- +# Go +# --------------------------------------------------------------------------- + + +def go_native( + *, + method: str, + url: str, + headers: dict[str, str], + body: str | None = None, + options: SnippetOptions | None = None, +) -> str: + """Generate a Go ``net/http`` snippet.""" + options = options or {} + ind = indent_str(options) + timeout = options.get("request_timeout", 0) + + lines = ["package main", "", "import ("] + imports = ['"fmt"', '"io"', '"net/http"'] + if body: + imports.append('"strings"') + if timeout: + imports.append('"time"') + for imp in sorted(imports): + lines.append(f"{ind}{imp}") + lines.append(")") + lines.append("") + lines.append("func main() {") + + if body: + lines.append(f"{ind}body := strings.NewReader(`{body}`)") + lines.append(f'{ind}req, err := http.NewRequest("{method.upper()}", "{url}", body)') + else: + lines.append(f'{ind}req, err := http.NewRequest("{method.upper()}", "{url}", nil)') + lines.append(f"{ind}if err != nil {{") + lines.append(f"{ind}{ind}panic(err)") + lines.append(f"{ind}}}") + + for key, value in headers.items(): + lines.append(f'{ind}req.Header.Set("{key}", "{value}")') + + lines.append("") + no_redirect = not options.get("follow_redirect", True) + if timeout and no_redirect: + lines.append(f"{ind}client := &http.Client{{") + lines.append(f"{ind}{ind}Timeout: {timeout} * time.Second,") + lines.append( + f"{ind}{ind}CheckRedirect: func(req *http.Request, via []*http.Request) error {{" + ) + lines.append(f"{ind}{ind}{ind}return http.ErrUseLastResponse") + lines.append(f"{ind}{ind}}},") + lines.append(f"{ind}}}") + elif timeout: + lines.append(f"{ind}client := &http.Client{{Timeout: {timeout} * time.Second}}") + elif no_redirect: + lines.append(f"{ind}client := &http.Client{{") + lines.append( + f"{ind}{ind}CheckRedirect: func(req *http.Request, via []*http.Request) error {{" + ) + lines.append(f"{ind}{ind}{ind}return http.ErrUseLastResponse") + lines.append(f"{ind}{ind}}},") + lines.append(f"{ind}}}") + else: + lines.append(f"{ind}client := &http.Client{{}}") + lines.append(f"{ind}resp, err := client.Do(req)") + lines.append(f"{ind}if err != nil {{") + lines.append(f"{ind}{ind}panic(err)") + lines.append(f"{ind}}}") + lines.append(f"{ind}defer resp.Body.Close()") + lines.append("") + lines.append(f"{ind}respBody, _ := io.ReadAll(resp.Body)") + lines.append(f"{ind}fmt.Println(string(respBody))") + lines.append("}") + return "\n".join(lines) + + +# --------------------------------------------------------------------------- +# Rust +# --------------------------------------------------------------------------- + + +def rust_reqwest( + *, + method: str, + url: str, + headers: dict[str, str], + body: str | None = None, + options: SnippetOptions | None = None, +) -> str: + """Generate a Rust ``reqwest`` snippet.""" + options = options or {} + ind = indent_str(options) + timeout = options.get("request_timeout", 0) + + lines = ["use reqwest;", ""] + lines.append("#[tokio::main]") + lines.append("async fn main() -> Result<(), reqwest::Error> {") + + no_redirect = not options.get("follow_redirect", True) + if timeout or no_redirect: + lines.append(f"{ind}let client = reqwest::Client::builder()") + if timeout: + lines.append(f"{ind}{ind}.timeout(std::time::Duration::from_secs({timeout}))") + if no_redirect: + lines.append(f"{ind}{ind}.redirect(reqwest::redirect::Policy::none())") + lines.append(f"{ind}{ind}.build()?;") + else: + lines.append(f"{ind}let client = reqwest::Client::new();") + lines.append("") + + method_lower = method.lower() + lines.append(f'{ind}let response = client.{method_lower}("{url}")') + for key, value in headers.items(): + lines.append(f'{ind}{ind}.header("{key}", "{value}")') + if body: + lines.append(f"{ind}{ind}.body({json.dumps(body)})") + lines.append(f"{ind}{ind}.send()") + lines.append(f"{ind}{ind}.await?;") + lines.append("") + lines.append(f"{ind}let body = response.text().await?;") + lines.append(f'{ind}println!("{{}}", body);') + lines.append(f"{ind}Ok(())") + lines.append("}") + return "\n".join(lines) + + +# --------------------------------------------------------------------------- +# C (libcurl) +# --------------------------------------------------------------------------- + + +def c_libcurl( + *, + method: str, + url: str, + headers: dict[str, str], + body: str | None = None, + options: SnippetOptions | None = None, +) -> str: + """Generate a C libcurl snippet.""" + options = options or {} + ind = indent_str(options) + timeout = options.get("request_timeout", 0) + boilerplate = options.get("include_boilerplate", True) + + lines: list[str] = [] + if boilerplate: + lines.extend(["#include ", "#include ", ""]) + lines.append("int main(void) {") + lines.append(f"{ind}CURL *curl = curl_easy_init();") + lines.append(f"{ind}if (curl) {{") + ind2 = ind * 2 + else: + lines.append("CURL *curl = curl_easy_init();") + lines.append("if (curl) {") + ind2 = ind + + lines.append(f'{ind2}curl_easy_setopt(curl, CURLOPT_URL, "{url}");') + lines.append(f'{ind2}curl_easy_setopt(curl, CURLOPT_CUSTOMREQUEST, "{method.upper()}");') + if options.get("follow_redirect", True): + lines.append(f"{ind2}curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);") + if timeout: + lines.append(f"{ind2}curl_easy_setopt(curl, CURLOPT_TIMEOUT, {timeout}L);") + + if headers: + lines.append("") + lines.append(f"{ind2}struct curl_slist *headers = NULL;") + for key, value in headers.items(): + lines.append(f'{ind2}headers = curl_slist_append(headers, "{key}: {value}");') + lines.append(f"{ind2}curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);") + + if body: + escaped = body.replace('"', '\\"') + lines.append(f'{ind2}curl_easy_setopt(curl, CURLOPT_POSTFIELDS, "{escaped}");') + + lines.append("") + lines.append(f"{ind2}CURLcode res = curl_easy_perform(curl);") + if headers: + lines.append(f"{ind2}curl_slist_free_all(headers);") + lines.append(f"{ind2}curl_easy_cleanup(curl);") + + if boilerplate: + lines.append(f"{ind}}}") + lines.append(f"{ind}return 0;") + lines.append("}") + else: + lines.append("}") + return "\n".join(lines) + + +# --------------------------------------------------------------------------- +# Swift +# --------------------------------------------------------------------------- + + +def swift_urlsession( + *, + method: str, + url: str, + headers: dict[str, str], + body: str | None = None, + options: SnippetOptions | None = None, +) -> str: + """Generate a Swift ``URLSession`` snippet.""" + options = options or {} + ind = indent_str(options) + timeout = options.get("request_timeout", 0) + boilerplate = options.get("include_boilerplate", True) + + lines: list[str] = [] + if boilerplate: + lines.extend(["import Foundation", ""]) + lines.append(f'let url = URL(string: "{url}")!') + lines.append("var request = URLRequest(url: url)") + lines.append(f'request.httpMethod = "{method.upper()}"') + if timeout: + lines.append(f"request.timeoutInterval = {timeout}") + for key, value in headers.items(): + lines.append(f'request.setValue("{value}", forHTTPHeaderField: "{key}")') + if body: + escaped = body.replace('"', '\\"') + lines.append(f'request.httpBody = "{escaped}".data(using: .utf8)') + lines.append("") + lines.append("let task = URLSession.shared.dataTask(with: request) { data, response, error in") + lines.append(f"{ind}if let data = data {{") + lines.append(f'{ind}{ind}print(String(data: data, encoding: .utf8) ?? "")') + lines.append(f"{ind}}}") + lines.append("}") + lines.append("task.resume()") + return "\n".join(lines) + + +# --------------------------------------------------------------------------- +# Java +# --------------------------------------------------------------------------- + + +def java_okhttp( + *, + method: str, + url: str, + headers: dict[str, str], + body: str | None = None, + options: SnippetOptions | None = None, +) -> str: + """Generate a Java OkHttp snippet.""" + options = options or {} + ind = indent_str(options) + timeout = options.get("request_timeout", 0) + boilerplate = options.get("include_boilerplate", True) + + lines: list[str] = [] + if boilerplate: + lines.extend( + [ + "import okhttp3.*;", + "", + "public class Main {", + f"{ind}public static void main(String[] args) throws Exception {{", + ] + ) + ind2 = ind * 2 + else: + ind2 = ind + + if timeout: + lines.append(f"{ind2}OkHttpClient client = new OkHttpClient.Builder()") + lines.append( + f"{ind2}{ind}.connectTimeout({timeout}, java.util.concurrent.TimeUnit.SECONDS)" + ) + lines.append(f"{ind2}{ind}.readTimeout({timeout}, java.util.concurrent.TimeUnit.SECONDS)") + if not options.get("follow_redirect", True): + lines.append(f"{ind2}{ind}.followRedirects(false)") + lines.append(f"{ind2}{ind}.build();") + elif not options.get("follow_redirect", True): + lines.append(f"{ind2}OkHttpClient client = new OkHttpClient.Builder()") + lines.append(f"{ind2}{ind}.followRedirects(false)") + lines.append(f"{ind2}{ind}.build();") + else: + lines.append(f"{ind2}OkHttpClient client = new OkHttpClient();") + lines.append("") + + method_upper = method.upper() + needs_body = method_upper in ("POST", "PUT", "PATCH") + if needs_body and body: + ct = headers.get("Content-Type", "application/json") + escaped = body.replace('"', '\\"') + lines.append( + f'{ind2}RequestBody body = RequestBody.create("{escaped}", MediaType.parse("{ct}"));' + ) + elif needs_body: + lines.append( + f'{ind2}RequestBody body = RequestBody.create("", MediaType.parse("application/json"));' + ) + + lines.append(f"{ind2}Request request = new Request.Builder()") + lines.append(f'{ind2}{ind}.url("{url}")') + for key, value in headers.items(): + lines.append(f'{ind2}{ind}.addHeader("{key}", "{value}")') + if needs_body: + lines.append(f"{ind2}{ind}.{method_lower_name(method_upper)}(body)") + else: + lines.append(f"{ind2}{ind}.{method_lower_name(method_upper)}()") + lines.append(f"{ind2}{ind}.build();") + lines.append("") + lines.append(f"{ind2}Response response = client.newCall(request).execute();") + lines.append(f"{ind2}System.out.println(response.body().string());") + if boilerplate: + lines.append(f"{ind}}}") + lines.append("}") + return "\n".join(lines) + + +def method_lower_name(method: str) -> str: + """Map HTTP method to OkHttp builder method name.""" + mapping = { + "GET": "get", + "POST": "post", + "PUT": "put", + "PATCH": "patch", + "DELETE": "delete", + "HEAD": "head", + } + return mapping.get(method, "get") + + +# --------------------------------------------------------------------------- +# Kotlin +# --------------------------------------------------------------------------- + + +def kotlin_okhttp( + *, + method: str, + url: str, + headers: dict[str, str], + body: str | None = None, + options: SnippetOptions | None = None, +) -> str: + """Generate a Kotlin OkHttp snippet.""" + options = options or {} + ind = indent_str(options) + timeout = options.get("request_timeout", 0) + boilerplate = options.get("include_boilerplate", True) + + lines: list[str] = [] + if boilerplate: + lines.extend( + [ + "import okhttp3.*", + "import okhttp3.MediaType.Companion.toMediaType", + "import okhttp3.RequestBody.Companion.toRequestBody", + "", + "fun main() {", + ] + ) + + if timeout: + lines.append(f"{ind}val client = OkHttpClient.Builder()") + lines.append(f"{ind}{ind}.connectTimeout({timeout}, java.util.concurrent.TimeUnit.SECONDS)") + lines.append(f"{ind}{ind}.readTimeout({timeout}, java.util.concurrent.TimeUnit.SECONDS)") + if not options.get("follow_redirect", True): + lines.append(f"{ind}{ind}.followRedirects(false)") + lines.append(f"{ind}{ind}.build()") + elif not options.get("follow_redirect", True): + lines.append(f"{ind}val client = OkHttpClient.Builder()") + lines.append(f"{ind}{ind}.followRedirects(false)") + lines.append(f"{ind}{ind}.build()") + else: + lines.append(f"{ind}val client = OkHttpClient()") + lines.append("") + + method_upper = method.upper() + needs_body = method_upper in ("POST", "PUT", "PATCH") + if needs_body and body: + ct = headers.get("Content-Type", "application/json") + escaped = body.replace('"', '\\"') + lines.append(f'{ind}val body = "{escaped}".toRequestBody("{ct}".toMediaType())') + elif needs_body: + lines.append(f'{ind}val body = "".toRequestBody("application/json".toMediaType())') + + lines.append(f"{ind}val request = Request.Builder()") + lines.append(f'{ind}{ind}.url("{url}")') + for key, value in headers.items(): + lines.append(f'{ind}{ind}.addHeader("{key}", "{value}")') + if needs_body: + lines.append(f"{ind}{ind}.{method_lower_name(method_upper)}(body)") + else: + lines.append(f"{ind}{ind}.{method_lower_name(method_upper)}()") + lines.append(f"{ind}{ind}.build()") + lines.append("") + lines.append(f"{ind}val response = client.newCall(request).execute()") + lines.append(f'{ind}println(response.body?.string() ?: "")') + if boilerplate: + lines.append("}") + return "\n".join(lines) + + +# --------------------------------------------------------------------------- +# C# +# --------------------------------------------------------------------------- + + +def csharp_httpclient( + *, + method: str, + url: str, + headers: dict[str, str], + body: str | None = None, + options: SnippetOptions | None = None, +) -> str: + """Generate a C# ``HttpClient`` snippet.""" + options = options or {} + timeout = options.get("request_timeout", 0) + boilerplate = options.get("include_boilerplate", True) + + lines: list[str] = [] + if boilerplate: + lines.append("using System.Net.Http;") + lines.append("") + no_redirect = not options.get("follow_redirect", True) + if no_redirect: + lines.append("var handler = new HttpClientHandler { AllowAutoRedirect = false };") + lines.append("var client = new HttpClient(handler);") + else: + lines.append("var client = new HttpClient();") + if timeout: + lines.append(f"client.Timeout = TimeSpan.FromSeconds({timeout});") + + method_map = { + "GET": "HttpMethod.Get", + "POST": "HttpMethod.Post", + "PUT": "HttpMethod.Put", + "PATCH": "HttpMethod.Patch", + "DELETE": "HttpMethod.Delete", + "HEAD": "HttpMethod.Head", + "OPTIONS": "HttpMethod.Options", + } + http_method = method_map.get(method.upper(), "HttpMethod.Get") + + lines.append("") + lines.append(f'var request = new HttpRequestMessage({http_method}, "{url}");') + for key, value in headers.items(): + if key.lower() == "content-type": + continue # Content-Type is set on content, not request headers + lines.append(f'request.Headers.Add("{key}", "{value}");') + + if body: + ct = headers.get("Content-Type", "application/json") + escaped = body.replace('"', '\\"') + lines.append( + f'request.Content = new StringContent("{escaped}", System.Text.Encoding.UTF8, "{ct}");' + ) + lines.append("") + lines.append("var response = await client.SendAsync(request);") + lines.append("var content = await response.Content.ReadAsStringAsync();") + lines.append("Console.WriteLine(content);") + return "\n".join(lines) + + +def csharp_restsharp( + *, + method: str, + url: str, + headers: dict[str, str], + body: str | None = None, + options: SnippetOptions | None = None, +) -> str: + """Generate a C# ``RestSharp`` snippet.""" + options = options or {} + timeout = options.get("request_timeout", 0) + boilerplate = options.get("include_boilerplate", True) + + method_map = { + "GET": "Method.Get", + "POST": "Method.Post", + "PUT": "Method.Put", + "PATCH": "Method.Patch", + "DELETE": "Method.Delete", + "HEAD": "Method.Head", + "OPTIONS": "Method.Options", + } + rest_method = method_map.get(method.upper(), "Method.Get") + + lines: list[str] = [] + if boilerplate: + lines.append("using RestSharp;") + lines.append("") + opts_parts: list[str] = [] + if timeout: + opts_parts.append(f"MaxTimeout = {timeout * 1000}") + if not options.get("follow_redirect", True): + opts_parts.append("FollowRedirects = false") + if opts_parts: + opts_str = ", ".join(opts_parts) + lines.append( + f'var client = new RestClient(new RestClientOptions("{url}") {{ {opts_str} }});' + ) + else: + lines.append(f'var client = new RestClient("{url}");') + + lines.append(f'var request = new RestRequest("", {rest_method});') + for key, value in headers.items(): + lines.append(f'request.AddHeader("{key}", "{value}");') + if body: + ct = headers.get("Content-Type", "application/json") + escaped = body.replace('"', '\\"') + lines.append(f'request.AddStringBody("{escaped}", "{ct}");') + lines.append("") + lines.append("var response = await client.ExecuteAsync(request);") + lines.append("Console.WriteLine(response.Content);") + return "\n".join(lines) + + +# --------------------------------------------------------------------------- +# Registry entries +# --------------------------------------------------------------------------- + +# Per-language option tuples +_OPT_STD = ("indent_count", "indent_type", "trim_body", "request_timeout", "follow_redirect") + +COMPILED_LANGUAGES: list[LanguageEntry] = [ + LanguageEntry( + "C (libcurl)", + "c", + (*_OPT_STD, "include_boilerplate"), + c_libcurl, + ), + LanguageEntry( + "C# (HttpClient)", + "csharp", + (*_OPT_STD, "include_boilerplate"), + csharp_httpclient, + ), + LanguageEntry( + "C# (RestSharp)", + "csharp", + (*_OPT_STD, "include_boilerplate"), + csharp_restsharp, + ), + LanguageEntry("Go (net/http)", "go", _OPT_STD, go_native), + LanguageEntry( + "Java (OkHttp)", + "java", + (*_OPT_STD, "include_boilerplate"), + java_okhttp, + ), + LanguageEntry( + "Kotlin (OkHttp)", + "kotlin", + (*_OPT_STD, "include_boilerplate"), + kotlin_okhttp, + ), + LanguageEntry("Rust (reqwest)", "rust", _OPT_STD, rust_reqwest), + LanguageEntry( + "Swift (URLSession)", + "swift", + ("indent_count", "indent_type", "trim_body", "request_timeout", "include_boilerplate"), + swift_urlsession, + ), +] diff --git a/src/services/http/snippet_generator/dynamic_snippets.py b/src/services/http/snippet_generator/dynamic_snippets.py new file mode 100644 index 0000000..a951e79 --- /dev/null +++ b/src/services/http/snippet_generator/dynamic_snippets.py @@ -0,0 +1,555 @@ +"""Snippet generators for interpreted / dynamic languages. + +Generators for Python, JavaScript, Node.js, Ruby, PHP, and Dart. +""" + +from __future__ import annotations + +import json + +from services.http.snippet_generator.generator import LanguageEntry, SnippetOptions, indent_str + +# --------------------------------------------------------------------------- +# Python +# --------------------------------------------------------------------------- + + +def python_requests( + *, + method: str, + url: str, + headers: dict[str, str], + body: str | None = None, + options: SnippetOptions | None = None, +) -> str: + """Generate a Python ``requests`` snippet.""" + options = options or {} + ind = indent_str(options) + lines = ["import requests", ""] + + if headers: + lines.append(f"headers = {json.dumps(headers, indent=len(ind))}") + lines.append("") + + call = f'response = requests.{method.lower()}("{url}"' + if headers: + call += ", headers=headers" + if body: + try: + json.loads(body) + lines.append(f"payload = {body}") + lines.append("") + call += ", json=payload" + except (json.JSONDecodeError, TypeError): + call += f', data="""{body}"""' + + timeout = options.get("request_timeout", 0) + if timeout: + call += f", timeout={timeout}" + if not options.get("follow_redirect", True): + call += ", allow_redirects=False" + call += ")" + lines.append(call) + lines.append("print(response.status_code)") + lines.append("print(response.text)") + return "\n".join(lines) + + +def python_http_client( + *, + method: str, + url: str, + headers: dict[str, str], + body: str | None = None, + options: SnippetOptions | None = None, +) -> str: + """Generate a Python ``http.client`` snippet.""" + from urllib.parse import urlparse + + options = options or {} + ind = indent_str(options) + parsed = urlparse(url) + host = parsed.hostname or "" + path = parsed.path or "/" + if parsed.query: + path = f"{path}?{parsed.query}" + use_https = parsed.scheme == "https" + module = "http.client" + + lines = [f"import {module}", ""] + conn_class = "HTTPSConnection" if use_https else "HTTPConnection" + timeout = options.get("request_timeout", 0) + conn_args = f'"{host}"' + if parsed.port: + conn_args += f", {parsed.port}" + if timeout: + conn_args += f", timeout={timeout}" + lines.append(f"conn = {module}.{conn_class}({conn_args})") + lines.append("") + + if headers: + lines.append(f"headers = {json.dumps(headers, indent=len(ind))}") + lines.append("") + + body_arg = f'"{body}"' if body else "None" + hdr_arg = "headers" if headers else "{}" + lines.append(f'conn.request("{method.upper()}", "{path}", body={body_arg}, headers={hdr_arg})') + lines.append("res = conn.getresponse()") + lines.append("print(res.status, res.reason)") + lines.append("print(res.read().decode())") + return "\n".join(lines) + + +# --------------------------------------------------------------------------- +# JavaScript / Browser +# --------------------------------------------------------------------------- + + +def javascript_fetch( + *, + method: str, + url: str, + headers: dict[str, str], + body: str | None = None, + options: SnippetOptions | None = None, +) -> str: + """Generate a JavaScript ``fetch`` snippet.""" + options = options or {} + ind = indent_str(options) + use_async = options.get("async_await", False) + opts_lines: list[str] = [f'{ind}method: "{method.upper()}"'] + if headers: + hdr_str = json.dumps(headers, indent=len(ind)) + hdr_lines = hdr_str.splitlines() + indented = "\n".join( + f"{ind}{ind}{line}" if i > 0 else f"{ind}headers: {line}" + for i, line in enumerate(hdr_lines) + ) + opts_lines.append(indented) + if body: + opts_lines.append(f"{ind}body: {json.dumps(body)}") + if not options.get("follow_redirect", True): + opts_lines.append(f'{ind}redirect: "manual"') + opts_block = ",\n".join(opts_lines) + if use_async: + lines = [ + "async function makeRequest() {", + f'{ind}const response = await fetch("{url}", {{', + ] + for line in opts_block.splitlines(): + lines.append(f"{ind}{line}") + lines.append(f"{ind}}});") + lines.append(f"{ind}const data = await response.json();") + lines.append(f"{ind}console.log(data);") + lines.append("}") + lines.append("") + lines.append("makeRequest();") + return "\n".join(lines) + return ( + f'fetch("{url}", {{\n' + f"{opts_block}\n" + f"}})\n" + f"{ind}.then(response => response.json())\n" + f"{ind}.then(data => console.log(data));" + ) + + +def javascript_xhr( + *, + method: str, + url: str, + headers: dict[str, str], + body: str | None = None, + options: SnippetOptions | None = None, +) -> str: + """Generate a JavaScript ``XMLHttpRequest`` snippet.""" + options = options or {} + ind = indent_str(options) + timeout = options.get("request_timeout", 0) + lines = ["var xhr = new XMLHttpRequest();"] + lines.append(f'xhr.open("{method.upper()}", "{url}");') + if timeout: + lines.append(f"xhr.timeout = {timeout * 1000};") + for key, value in headers.items(): + lines.append(f'xhr.setRequestHeader("{key}", "{value}");') + lines.append("") + lines.append("xhr.onreadystatechange = function () {") + lines.append(f"{ind}if (xhr.readyState === 4) {{") + lines.append(f"{ind}{ind}console.log(xhr.status);") + lines.append(f"{ind}{ind}console.log(xhr.responseText);") + lines.append(f"{ind}}}") + lines.append("};") + lines.append("") + if body: + lines.append(f"xhr.send({json.dumps(body)});") + else: + lines.append("xhr.send();") + return "\n".join(lines) + + +# --------------------------------------------------------------------------- +# Node.js +# --------------------------------------------------------------------------- + + +def nodejs_axios( + *, + method: str, + url: str, + headers: dict[str, str], + body: str | None = None, + options: SnippetOptions | None = None, +) -> str: + """Generate a Node.js ``axios`` snippet.""" + options = options or {} + ind = indent_str(options) + timeout = options.get("request_timeout", 0) + use_async = options.get("async_await", False) + + lines = ['const axios = require("axios");', ""] + config_parts: list[str] = [ + f'{ind}method: "{method.lower()}"', + f'{ind}url: "{url}"', + ] + if headers: + hdr_str = json.dumps(headers, indent=len(ind)) + hdr_lines = hdr_str.splitlines() + indented = "\n".join( + f"{ind}{ind}{line}" if i > 0 else f"{ind}headers: {line}" + for i, line in enumerate(hdr_lines) + ) + config_parts.append(indented) + if body: + try: + json.loads(body) + config_parts.append(f"{ind}data: {body}") + except (json.JSONDecodeError, TypeError): + config_parts.append(f"{ind}data: {json.dumps(body)}") + if timeout: + config_parts.append(f"{ind}timeout: {timeout * 1000}") + if not options.get("follow_redirect", True): + config_parts.append(f"{ind}maxRedirects: 0") + config_block = ",\n".join(config_parts) + + if use_async: + lines.append("async function makeRequest() {") + lines.append(f"{ind}const response = await axios({{") + for line in config_block.splitlines(): + lines.append(f"{ind}{line}") + lines.append(f"{ind}}});") + lines.append(f"{ind}console.log(response.data);") + lines.append("}") + lines.append("") + lines.append("makeRequest();") + else: + lines.append(f"axios({{\n{config_block}\n}})") + lines.append(f"{ind}.then(response => console.log(response.data))") + lines.append(f"{ind}.catch(error => console.error(error));") + return "\n".join(lines) + + +def nodejs_native( + *, + method: str, + url: str, + headers: dict[str, str], + body: str | None = None, + options: SnippetOptions | None = None, +) -> str: + """Generate a Node.js native ``http``/``https`` snippet.""" + from urllib.parse import urlparse + + options = options or {} + ind = indent_str(options) + timeout = options.get("request_timeout", 0) + use_es6 = options.get("es6_features", False) + + parsed = urlparse(url) + use_https = parsed.scheme == "https" + mod = "https" if use_https else "http" + + if use_es6: + lines = [f'import {mod} from "{mod}";', ""] + else: + lines = [f'const {mod} = require("{mod}");', ""] + + opt_parts: list[str] = [] + if parsed.hostname: + opt_parts.append(f'{ind}hostname: "{parsed.hostname}"') + port = parsed.port or (443 if use_https else 80) + opt_parts.append(f"{ind}port: {port}") + path = parsed.path or "/" + if parsed.query: + path = f"{path}?{parsed.query}" + opt_parts.append(f'{ind}path: "{path}"') + opt_parts.append(f'{ind}method: "{method.upper()}"') + if headers: + hdr_str = json.dumps(headers, indent=len(ind)) + hdr_lines = hdr_str.splitlines() + indented = "\n".join( + f"{ind}{ind}{line}" if i > 0 else f"{ind}headers: {line}" + for i, line in enumerate(hdr_lines) + ) + opt_parts.append(indented) + if timeout: + opt_parts.append(f"{ind}timeout: {timeout * 1000}") + opts_block = ",\n".join(opt_parts) + + lines.append(f"const options = {{\n{opts_block}\n}};") + lines.append("") + cb = "(res) => {" + lines.append(f"const req = {mod}.request(options, {cb}") + lines.append(f'{ind}let data = "";') + lines.append(f'{ind}res.on("data", (chunk) => {{ data += chunk; }});') + lines.append(f'{ind}res.on("end", () => {{ console.log(data); }});') + lines.append("});") + lines.append("") + lines.append('req.on("error", (error) => { console.error(error); });') + if body: + lines.append(f"req.write({json.dumps(body)});") + lines.append("req.end();") + return "\n".join(lines) + + +# --------------------------------------------------------------------------- +# Ruby +# --------------------------------------------------------------------------- + + +def ruby_nethttp( + *, + method: str, + url: str, + headers: dict[str, str], + body: str | None = None, + options: SnippetOptions | None = None, +) -> str: + """Generate a Ruby ``Net::HTTP`` snippet.""" + from urllib.parse import urlparse + + options = options or {} + timeout = options.get("request_timeout", 0) + + parsed = urlparse(url) + + lines = ['require "net/http"', 'require "uri"', ""] + lines.append(f'uri = URI.parse("{url}")') + lines.append("") + + method_upper = method.upper() + method_map = { + "GET": "Get", + "POST": "Post", + "PUT": "Put", + "PATCH": "Patch", + "DELETE": "Delete", + "HEAD": "Head", + "OPTIONS": "Options", + } + rb_method = method_map.get(method_upper, "Get") + + lines.append(f"request = Net::HTTP::{rb_method}.new(uri)") + for key, value in headers.items(): + lines.append(f'request["{key}"] = "{value}"') + if body: + lines.append(f"request.body = '{body}'") + lines.append("") + + lines.append("http = Net::HTTP.new(uri.host, uri.port)") + if parsed.scheme == "https": + lines.append("http.use_ssl = true") + if timeout: + lines.append(f"http.read_timeout = {timeout}") + lines.append(f"http.open_timeout = {timeout}") + if not options.get("follow_redirect", True): + lines.append("http.max_retries = 0") + lines.append("") + lines.append("response = http.request(request)") + lines.append("puts response.code") + lines.append("puts response.body") + return "\n".join(lines) + + +# --------------------------------------------------------------------------- +# PHP +# --------------------------------------------------------------------------- + + +def php_curl( + *, + method: str, + url: str, + headers: dict[str, str], + body: str | None = None, + options: SnippetOptions | None = None, +) -> str: + """Generate a PHP cURL snippet.""" + options = options or {} + ind = indent_str(options) + timeout = options.get("request_timeout", 0) + + lines = [" str: + """Generate a PHP Guzzle snippet.""" + options = options or {} + ind = indent_str(options) + timeout = options.get("request_timeout", 0) + + lines = [" '{v}'" for k, v in headers.items()) + opts_parts.append(f"{ind}'headers' => [{hdr_items}]") + if body: + try: + json.loads(body) + opts_parts.append(f"{ind}'json' => json_decode('{body}', true)") + except (json.JSONDecodeError, TypeError): + escaped = body.replace("'", "\\'") + opts_parts.append(f"{ind}'body' => '{escaped}'") + if timeout: + opts_parts.append(f"{ind}'timeout' => {timeout}") + if not options.get("follow_redirect", True): + opts_parts.append(f"{ind}'allow_redirects' => false") + opts_block = ",\n".join(opts_parts) + + lines.append(f'$response = $client->request("{method.upper()}", "{url}", [') + if opts_block: + lines.append(opts_block) + lines.append("]);") + lines.append("") + lines.append("echo $response->getBody();") + return "\n".join(lines) + + +# --------------------------------------------------------------------------- +# Dart +# --------------------------------------------------------------------------- + + +def dart_http( + *, + method: str, + url: str, + headers: dict[str, str], + body: str | None = None, + options: SnippetOptions | None = None, +) -> str: + """Generate a Dart ``http`` package snippet.""" + options = options or {} + ind = indent_str(options) + boilerplate = options.get("include_boilerplate", True) + + lines: list[str] = [] + if boilerplate: + lines.append("import 'package:http/http.dart' as http;") + lines.append("") + lines.append("void main() async {") + + url_line = f"{ind}var url = Uri.parse('{url}');" + lines.append(url_line) + + if headers: + hdr_items = ", ".join(f"'{k}': '{v}'" for k, v in headers.items()) + lines.append(f"{ind}var headers = {{{hdr_items}}};") + + method_lower = method.lower() + method_map = {"get", "post", "put", "patch", "delete", "head"} + fn = method_lower if method_lower in method_map else "get" + + call = f"{ind}var response = await http.{fn}(url" + if headers: + call += ", headers: headers" + if body and fn in ("post", "put", "patch"): + call += f", body: '{body}'" + call += ");" + lines.append(call) + + lines.append(f"{ind}print(response.statusCode);") + lines.append(f"{ind}print(response.body);") + lines.append("}") + return "\n".join(lines) + + +# --------------------------------------------------------------------------- +# Registry entries +# --------------------------------------------------------------------------- + +# Per-language option tuples +_OPT_STD = ("indent_count", "indent_type", "trim_body", "request_timeout", "follow_redirect") + +DYNAMIC_LANGUAGES: list[LanguageEntry] = [ + LanguageEntry( + "Dart (http)", + "dart", + (*_OPT_STD, "include_boilerplate"), + dart_http, + ), + LanguageEntry( + "JavaScript (fetch)", + "javascript", + (*_OPT_STD, "async_await"), + javascript_fetch, + ), + LanguageEntry( + "JavaScript (XHR)", + "javascript", + ("indent_count", "indent_type", "trim_body", "request_timeout"), + javascript_xhr, + ), + LanguageEntry( + "NodeJS (Axios)", + "javascript", + (*_OPT_STD, "async_await"), + nodejs_axios, + ), + LanguageEntry( + "NodeJS (Native)", + "javascript", + (*_OPT_STD, "es6_features"), + nodejs_native, + ), + LanguageEntry("PHP (cURL)", "php", _OPT_STD, php_curl), + LanguageEntry("PHP (Guzzle)", "php", _OPT_STD, php_guzzle), + LanguageEntry( + "Python (http.client)", + "python", + ("indent_count", "indent_type", "trim_body", "request_timeout"), + python_http_client, + ), + LanguageEntry("Python (requests)", "python", _OPT_STD, python_requests), + LanguageEntry("Ruby (Net::HTTP)", "ruby", _OPT_STD, ruby_nethttp), +] diff --git a/src/services/http/snippet_generator/generator.py b/src/services/http/snippet_generator/generator.py new file mode 100644 index 0000000..b4ae56f --- /dev/null +++ b/src/services/http/snippet_generator/generator.py @@ -0,0 +1,190 @@ +"""Core registry, dispatch, and shared helpers for snippet generation. + +Defines :class:`SnippetGenerator` (the public API), :class:`SnippetOptions` +(per-snippet configuration), :class:`LanguageEntry` (registry metadata), +and delegates auth injection to :func:`services.http.auth_handler.apply_auth`. +""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import NamedTuple, TypedDict + +from services.http.auth_handler import apply_auth +from services.http.header_utils import parse_header_dict + +# --------------------------------------------------------------------------- +# Options +# --------------------------------------------------------------------------- + +_DEFAULT_INDENT_COUNT = 2 +_DEFAULT_INDENT_TYPE = "space" + + +class SnippetOptions(TypedDict, total=False): + """Per-snippet configuration options. + + All fields are optional — missing keys fall back to defaults. + """ + + indent_count: int + """Number of indentation characters per level (default 2).""" + indent_type: str + """``"space"`` or ``"tab"`` (default ``"space"``).""" + trim_body: bool + """Strip leading/trailing whitespace from the request body (default False).""" + follow_redirect: bool + """Include a follow-redirects flag in shell commands (default True).""" + request_timeout: int + """Timeout in seconds; 0 means no timeout (default 0).""" + include_boilerplate: bool + """Include boilerplate code such as imports and main wrappers (default True).""" + async_await: bool + """Use async/await syntax instead of promise chains (default False).""" + es6_features: bool + """Use ES6+ syntax such as ``import`` and arrow functions (default False).""" + multiline: bool + """Split shell commands across multiple lines (default True).""" + long_form: bool + """Use long-form options like ``--header`` instead of ``-H`` (default True).""" + line_continuation: str + """Line continuation char: ``\\``, ``^``, or backtick (default ``\\``).""" + quote_type: str + """``"single"`` or ``"double"`` quotes around URLs (default ``"single"``).""" + follow_original_method: bool + """Keep original HTTP method on redirect instead of GET (default False).""" + silent_mode: bool + """Suppress progress meter / error messages (default False).""" + + +def resolve_options(options: SnippetOptions | None) -> SnippetOptions: + """Return *options* with defaults merged for any missing keys.""" + defaults: SnippetOptions = { + "indent_count": _DEFAULT_INDENT_COUNT, + "indent_type": _DEFAULT_INDENT_TYPE, + "trim_body": False, + "follow_redirect": True, + "request_timeout": 0, + "include_boilerplate": True, + "async_await": False, + "es6_features": False, + "multiline": True, + "long_form": True, + "line_continuation": "\\\\", + "quote_type": "single", + "follow_original_method": False, + "silent_mode": False, + } + if options: + defaults.update(options) + return defaults + + +def indent_str(options: SnippetOptions) -> str: + """Build a single indentation string from resolved *options*.""" + char = "\t" if options.get("indent_type") == "tab" else " " + return char * options.get("indent_count", _DEFAULT_INDENT_COUNT) + + +def prepare_body(body: str | None, options: SnippetOptions) -> str | None: + """Optionally trim whitespace from *body* per *options*.""" + if body is None: + return None + if options.get("trim_body"): + body = body.strip() + return body if body else None + + +# --------------------------------------------------------------------------- +# Language registry +# --------------------------------------------------------------------------- + + +class LanguageEntry(NamedTuple): + """Metadata for a single snippet language/library variant.""" + + display_name: str + lexer: str + applicable_options: tuple[str, ...] + generate: Callable[..., str] + + +def _build_registry() -> dict[str, LanguageEntry]: + """Import all generator modules and build the master registry.""" + from services.http.snippet_generator.compiled_snippets import COMPILED_LANGUAGES + from services.http.snippet_generator.dynamic_snippets import DYNAMIC_LANGUAGES + from services.http.snippet_generator.shell_snippets import SHELL_LANGUAGES + + registry: dict[str, LanguageEntry] = {} + for entries in (SHELL_LANGUAGES, DYNAMIC_LANGUAGES, COMPILED_LANGUAGES): + for entry in entries: + registry[entry.display_name] = entry + return registry + + +# Lazy-initialised singleton +_REGISTRY: dict[str, LanguageEntry] | None = None + + +def _get_registry() -> dict[str, LanguageEntry]: + """Return the language registry, building it on first call.""" + global _REGISTRY + if _REGISTRY is None: + _REGISTRY = _build_registry() + return _REGISTRY + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +class SnippetGenerator: + """Generate code snippets from request parameters. + + Every method is a ``@staticmethod`` — no shared state. + """ + + @staticmethod + def available_languages() -> list[str]: + """Return the sorted list of supported snippet language labels.""" + return sorted(_get_registry().keys()) + + @staticmethod + def get_language_info(name: str) -> LanguageEntry | None: + """Return the :class:`LanguageEntry` for *name*, or ``None``.""" + return _get_registry().get(name) + + @staticmethod + def generate( + language: str, + *, + method: str, + url: str, + headers: str | None = None, + body: str | None = None, + auth: dict | None = None, + options: SnippetOptions | None = None, + ) -> str: + """Generate a snippet for the given language label. + + The *language* parameter should be one of the values returned + by :meth:`available_languages`. Headers are accepted as a raw + newline-separated string and parsed internally. + """ + entry = _get_registry().get(language) + if entry is None: + return f"# Unsupported language: {language}" + + hdr = parse_header_dict(headers) + url, hdr = apply_auth(auth, url, hdr, method=method, body=body) + opts = resolve_options(options) + body = prepare_body(body, opts) + + return entry.generate( + method=method, + url=url, + headers=hdr, + body=body, + options=opts, + ) diff --git a/src/services/http/snippet_generator/shell_snippets.py b/src/services/http/snippet_generator/shell_snippets.py new file mode 100644 index 0000000..c29217d --- /dev/null +++ b/src/services/http/snippet_generator/shell_snippets.py @@ -0,0 +1,236 @@ +"""Shell and CLI snippet generators. + +Generators for command-line tools and text-based formats: +cURL, wget, HTTPie, raw HTTP, and PowerShell. +""" + +from __future__ import annotations + +import json +import shlex + +from services.http.snippet_generator.generator import LanguageEntry, SnippetOptions + + +def curl( + *, + method: str, + url: str, + headers: dict[str, str], + body: str | None = None, + options: SnippetOptions | None = None, +) -> str: + """Generate a cURL command.""" + options = options or {} + use_long = options.get("long_form", True) + multiline = options.get("multiline", True) + continuation = options.get("line_continuation", "\\") + quote = options.get("quote_type", "single") + silent = options.get("silent_mode", False) + + def _quote_url(raw: str) -> str: + """Quote the URL with single or double quotes.""" + if quote == "double": + return f'"{raw}"' + return f"'{raw}'" + + parts = ["curl"] + # Method + if use_long: + parts.append(f"--request {method.upper()}") + else: + parts.append(f"-X {method.upper()}") + parts.append(_quote_url(url)) + # Follow redirect + if options.get("follow_redirect", True): + parts.append("--location" if use_long else "-L") + # Follow original method + if options.get("follow_original_method", False): + parts.append("--post301") + parts.append("--post302") + parts.append("--post303") + # Silent mode + if silent: + parts.append("--silent" if use_long else "-s") + # Timeout + timeout = options.get("request_timeout", 0) + if timeout: + parts.append(f"--max-time {timeout}") + # Headers + for key, value in headers.items(): + flag = "--header" if use_long else "-H" + parts.append(f"{flag} {shlex.quote(f'{key}: {value}')}") + # Body + if body: + flag = "--data" if use_long else "-d" + parts.append(f"{flag} {shlex.quote(body)}") + if multiline: + return f" {continuation}\n ".join(parts) + return " ".join(parts) + + +def http_raw( + *, + method: str, + url: str, + headers: dict[str, str], + body: str | None = None, + options: SnippetOptions | None = None, +) -> str: + """Generate a raw HTTP/1.1 request.""" + # Extract host from URL + from urllib.parse import urlparse + + parsed = urlparse(url) + host = parsed.hostname or "" + path = parsed.path or "/" + if parsed.query: + path = f"{path}?{parsed.query}" + + lines = [f"{method.upper()} {path} HTTP/1.1", f"Host: {host}"] + for key, value in headers.items(): + lines.append(f"{key}: {value}") + if body: + lines.append(f"Content-Length: {len(body.encode())}") + lines.append("") + lines.append(body) + else: + lines.append("") + return "\n".join(lines) + + +def shell_wget( + *, + method: str, + url: str, + headers: dict[str, str], + body: str | None = None, + options: SnippetOptions | None = None, +) -> str: + """Generate a wget command.""" + options = options or {} + parts = ["wget", f"--method={method.upper()}"] + if not options.get("follow_redirect", True): + parts.append("--max-redirect=0") + timeout = options.get("request_timeout", 0) + if timeout: + parts.append(f"--timeout={timeout}") + for key, value in headers.items(): + parts.append(f"--header={shlex.quote(f'{key}: {value}')}") + if body: + parts.append(f"--body-data={shlex.quote(body)}") + parts.append("-O-") + parts.append(shlex.quote(url)) + return " \\\n ".join(parts) + + +def shell_httpie( + *, + method: str, + url: str, + headers: dict[str, str], + body: str | None = None, + options: SnippetOptions | None = None, +) -> str: + """Generate an HTTPie command.""" + options = options or {} + parts = ["http", method.upper(), shlex.quote(url)] + if not options.get("follow_redirect", True): + parts.append("--follow=false") + timeout = options.get("request_timeout", 0) + if timeout: + parts.append(f"--timeout={timeout}") + for key, value in headers.items(): + parts.append(f"{shlex.quote(key)}:{shlex.quote(value)}") + if body: + # Try JSON for inline body + try: + json.loads(body) + parts = ["echo", shlex.quote(body), "|", *parts] + parts.append("--json") + except (json.JSONDecodeError, TypeError): + parts = ["echo", shlex.quote(body), "|", *parts] + return " \\\n ".join(parts) + + +def powershell_restmethod( + *, + method: str, + url: str, + headers: dict[str, str], + body: str | None = None, + options: SnippetOptions | None = None, +) -> str: + """Generate a PowerShell Invoke-RestMethod snippet.""" + options = options or {} + lines: list[str] = [] + + if headers: + hdr_items = ", ".join(f'"{k}" = "{v}"' for k, v in headers.items()) + lines.append(f"$headers = @{{{hdr_items}}}") + lines.append("") + + if body: + escaped = body.replace("'", "''") + lines.append(f"$body = '{escaped}'") + lines.append("") + + call = f'Invoke-RestMethod -Uri "{url}" -Method {method.upper()}' + if headers: + call += " -Headers $headers" + if body: + call += " -Body $body" + # Add content-type if present + ct = headers.get("Content-Type") + if ct: + call += f' -ContentType "{ct}"' + if not options.get("follow_redirect", True): + call += " -MaximumRedirection 0" + timeout = options.get("request_timeout", 0) + if timeout: + call += f" -TimeoutSec {timeout}" + lines.append(call) + return "\n".join(lines) + + +# --------------------------------------------------------------------------- +# Registry entries +# --------------------------------------------------------------------------- + +SHELL_LANGUAGES: list[LanguageEntry] = [ + LanguageEntry( + "cURL", + "bash", + ( + "trim_body", + "request_timeout", + "follow_redirect", + "follow_original_method", + "multiline", + "long_form", + "line_continuation", + "quote_type", + "silent_mode", + ), + curl, + ), + LanguageEntry("HTTP", "http", ("trim_body",), http_raw), + LanguageEntry( + "PowerShell (RestMethod)", + "powershell", + ("trim_body", "request_timeout", "follow_redirect"), + powershell_restmethod, + ), + LanguageEntry( + "Shell (HTTPie)", + "bash", + ("request_timeout", "follow_redirect"), + shell_httpie, + ), + LanguageEntry( + "Shell (wget)", + "bash", + ("indent_count", "indent_type", "trim_body", "request_timeout", "follow_redirect"), + shell_wget, + ), +] diff --git a/src/ui/collections/collection_header.py b/src/ui/collections/collection_header.py index 2dde789..570180c 100644 --- a/src/ui/collections/collection_header.py +++ b/src/ui/collections/collection_header.py @@ -3,18 +3,17 @@ from __future__ import annotations from PySide6.QtCore import Qt, Signal -from PySide6.QtGui import QAction from PySide6.QtWidgets import ( QHBoxLayout, QLabel, QLineEdit, - QMenu, QSizePolicy, QToolButton, QVBoxLayout, QWidget, ) +from ui.collections.new_item_popup import NewItemPopup from ui.styling.icons import phi @@ -50,7 +49,7 @@ def __init__(self, parent: QWidget | None = None) -> None: top_row.addWidget(section_label) top_row.addStretch() - # "New" button with dropdown menu + # "New" button with icon-grid popup self._plus_btn = QToolButton(self) self._plus_btn.setText("New") self._plus_btn.setIcon(phi("plus")) @@ -60,23 +59,13 @@ def __init__(self, parent: QWidget | None = None) -> None: self._plus_btn.setToolTip("Create new collection or request") top_row.addWidget(self._plus_btn) - # Plus-menu - self._plus_menu = QMenu(self) - new_coll_act = QAction(phi("folder-plus"), "New collection", self) - self._plus_menu.addAction(new_coll_act) - self._new_req_act = QAction(phi("file-plus"), "New request", self) - self._new_req_act.setEnabled(False) - self._plus_menu.addAction(self._new_req_act) - - self._plus_btn.clicked.connect( - lambda: self._plus_menu.exec( - self._plus_btn.mapToGlobal(self._plus_btn.rect().bottomLeft()) - ) - ) - new_coll_act.triggered.connect(lambda: self.new_collection_requested.emit(None)) + # Dialog (replaces the old QMenu) + self._popup = NewItemPopup(self) + self._popup.new_request_clicked.connect(self._on_popup_new_request) + self._popup.new_collection_clicked.connect(lambda: self.new_collection_requested.emit(None)) + self._plus_btn.clicked.connect(lambda: self._popup.exec()) self._selected_collection_id: int | None = None - self._new_req_act.triggered.connect(self._on_new_request_clicked) # "Import" button self._import_btn = QToolButton(self) @@ -109,9 +98,7 @@ def __init__(self, parent: QWidget | None = None) -> None: def set_selected_collection_id(self, collection_id: int | None) -> None: """Update the currently selected collection for the 'New request' action.""" self._selected_collection_id = collection_id - self._new_req_act.setEnabled(collection_id is not None) - def _on_new_request_clicked(self) -> None: - """Emit ``new_request_requested`` with the currently selected collection.""" - if self._selected_collection_id is not None: - self.new_request_requested.emit(self._selected_collection_id) + def _on_popup_new_request(self) -> None: + """Emit ``new_request_requested`` -- ``None`` means draft (no collection).""" + self.new_request_requested.emit(None) diff --git a/src/ui/collections/collection_widget.py b/src/ui/collections/collection_widget.py index 1ff61e7..46ce80f 100644 --- a/src/ui/collections/collection_widget.py +++ b/src/ui/collections/collection_widget.py @@ -73,6 +73,9 @@ class CollectionWidget(QWidget): # Emitted when the initial background fetch completes load_finished = Signal() + # Emitted when the user wants a draft (unsaved) request tab + draft_request_requested = Signal() + def __init__(self, parent: QWidget | None = None) -> None: """Initialise the collection widget with header, tree, and loading bar.""" super().__init__(parent) @@ -124,8 +127,6 @@ def __init__(self, parent: QWidget | None = None) -> None: self._loading_bar.setGeometry(0, 0, viewport.width(), 4) self._loading_bar.hide() - self._start_fetch() - # ------------------------------------------------------------------ # Background fetch # ------------------------------------------------------------------ @@ -213,9 +214,13 @@ def _on_collection_moved(self, collection_id: int, new_parent_id: int | None) -> # Create helpers # ------------------------------------------------------------------ def _create_new_request(self, collection_id: int | None = None) -> None: - """Create a new request and add it to the tree.""" + """Create a new request and add it to the tree. + + When *collection_id* is ``None`` (draft mode), emits + :pyattr:`draft_request_requested` instead of persisting. + """ if collection_id is None: - logger.warning("Cannot create request without a collection_id") + self.draft_request_requested.emit() return new_request = self._svc.create_request( collection_id, _DEFAULT_METHOD, _DEFAULT_URL, _DEFAULT_REQUEST_NAME diff --git a/src/ui/collections/new_item_popup.py b/src/ui/collections/new_item_popup.py new file mode 100644 index 0000000..70b31c2 --- /dev/null +++ b/src/ui/collections/new_item_popup.py @@ -0,0 +1,135 @@ +"""Postman-style "Create New" dialog for creating new items. + +Displays a centered dialog window with a tile grid offering options +like "HTTP Request" and "Collection". Opened from the "New" button +in the collection sidebar header. +""" + +from __future__ import annotations + +from PySide6.QtCore import Qt, Signal +from PySide6.QtGui import QEnterEvent +from PySide6.QtWidgets import ( + QDialog, + QGridLayout, + QLabel, + QPushButton, + QSizePolicy, + QVBoxLayout, + QWidget, +) + +from ui.styling.icons import phi + + +class _Tile(QPushButton): + """A clickable icon tile with an icon above a label.""" + + hovered = Signal() + + def __init__( + self, + icon_name: str, + label: str, + parent: QWidget | None = None, + ) -> None: + """Initialise tile with a Phosphor icon and text label.""" + super().__init__(parent) + self.setObjectName("newItemTile") + self.setCursor(Qt.CursorShape.PointingHandCursor) + self.setFixedSize(140, 110) + + layout = QVBoxLayout(self) + layout.setContentsMargins(12, 18, 12, 12) + layout.setSpacing(8) + layout.setAlignment(Qt.AlignmentFlag.AlignCenter) + + icon_label = QLabel() + icon_label.setPixmap(phi(icon_name, size=36).pixmap(36, 36)) + icon_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + icon_label.setAttribute(Qt.WidgetAttribute.WA_TransparentForMouseEvents) + layout.addWidget(icon_label) + + text_label = QLabel(label) + text_label.setObjectName("newItemTileLabel") + text_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + text_label.setAttribute(Qt.WidgetAttribute.WA_TransparentForMouseEvents) + layout.addWidget(text_label) + + def enterEvent(self, event: QEnterEvent) -> None: + """Emit hovered signal on mouse enter.""" + super().enterEvent(event) + self.hovered.emit() + + +class NewItemPopup(QDialog): + """Centered dialog window for creating new requests or collections. + + Opened as a modal dialog from the "New" button — mirrors Postman's + "Create New" window. + """ + + new_request_clicked = Signal() + new_collection_clicked = Signal() + + def __init__(self, parent: QWidget | None = None) -> None: + """Initialise the dialog with a grid of item-type tiles.""" + super().__init__(parent) + self.setWindowTitle("Create New") + self.setObjectName("newItemPopup") + self.setFixedSize(380, 260) + + outer = QVBoxLayout(self) + outer.setContentsMargins(24, 20, 24, 20) + outer.setSpacing(4) + + # Title + title = QLabel("What do you want to create?") + title.setObjectName("newItemTitle") + title.setAlignment(Qt.AlignmentFlag.AlignCenter) + outer.addWidget(title) + + outer.addSpacing(12) + + # Tile grid — centered + grid = QGridLayout() + grid.setSpacing(16) + grid.setAlignment(Qt.AlignmentFlag.AlignCenter) + outer.addLayout(grid) + + # Tiles + http_tile = _Tile("globe", "HTTP Request", self) + collection_tile = _Tile("folder-plus", "Collection", self) + + grid.addWidget(http_tile, 0, 0) + grid.addWidget(collection_tile, 0, 1) + + http_tile.clicked.connect(self._on_http_clicked) + collection_tile.clicked.connect(self._on_collection_clicked) + + # Description area at the bottom + self._description = QLabel() + self._description.setObjectName("newItemDescription") + self._description.setWordWrap(True) + self._description.setAlignment(Qt.AlignmentFlag.AlignCenter) + self._description.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Preferred) + self._description.setFixedHeight(40) + self._description.setText("Create a new HTTP request or collection.") + outer.addWidget(self._description) + + http_tile.hovered.connect( + lambda: self._description.setText("Create a new HTTP request draft tab.") + ) + collection_tile.hovered.connect( + lambda: self._description.setText("Create a new collection to organise your requests.") + ) + + def _on_http_clicked(self) -> None: + """Emit signal and close dialog when HTTP tile is clicked.""" + self.new_request_clicked.emit() + self.accept() + + def _on_collection_clicked(self) -> None: + """Emit signal and close dialog when Collection tile is clicked.""" + self.new_collection_clicked.emit() + self.accept() diff --git a/src/ui/dialogs/code_snippet_dialog.py b/src/ui/dialogs/code_snippet_dialog.py deleted file mode 100644 index f7c01c2..0000000 --- a/src/ui/dialogs/code_snippet_dialog.py +++ /dev/null @@ -1,126 +0,0 @@ -"""Dialog for viewing and copying HTTP request code snippets. - -Shows generated code for the current request in various languages -(cURL, Python, JavaScript). -""" - -from __future__ import annotations - -from typing import ClassVar - -from PySide6.QtCore import Qt -from PySide6.QtGui import QClipboard, QGuiApplication -from PySide6.QtWidgets import ( - QComboBox, - QDialog, - QHBoxLayout, - QLabel, - QPushButton, - QVBoxLayout, - QWidget, -) - -from services.http.snippet_generator import SnippetGenerator -from ui.styling.icons import phi -from ui.widgets.code_editor import CodeEditorWidget - - -class CodeSnippetDialog(QDialog): - """Modal dialog displaying code snippets for a request. - - Instantiate with the request parameters and call :meth:`exec`. - """ - - def __init__( - self, - *, - method: str, - url: str, - headers: str | None = None, - body: str | None = None, - parent: QWidget | None = None, - ) -> None: - """Initialise the code snippet dialog.""" - super().__init__(parent) - self.setWindowTitle("Code Snippet") - self.resize(600, 400) - self.setModal(True) - - self._method = method - self._url = url - self._headers = headers - self._body = body - - layout = QVBoxLayout(self) - layout.setContentsMargins(16, 16, 16, 16) - layout.setSpacing(10) - - # Language selector - top_row = QHBoxLayout() - lang_label = QLabel("Language:") - lang_label.setObjectName("sectionLabel") - top_row.addWidget(lang_label) - - self._lang_combo = QComboBox() - self._lang_combo.addItems(SnippetGenerator.available_languages()) - self._lang_combo.setFixedWidth(200) - self._lang_combo.currentTextChanged.connect(self._refresh) - top_row.addWidget(self._lang_combo) - top_row.addStretch() - layout.addLayout(top_row) - - # Code display - self._code_edit = CodeEditorWidget(read_only=True) - layout.addWidget(self._code_edit, 1) - - # Button row - btn_row = QHBoxLayout() - btn_row.addStretch() - - self._copy_btn = QPushButton("Copy to Clipboard") - self._copy_btn.setIcon(phi("clipboard")) - self._copy_btn.setObjectName("primaryButton") - self._copy_btn.setCursor(Qt.CursorShape.PointingHandCursor) - self._copy_btn.clicked.connect(self._copy_to_clipboard) - btn_row.addWidget(self._copy_btn) - - self._status_label = QLabel("") - self._status_label.setObjectName("mutedLabel") - btn_row.addWidget(self._status_label) - - layout.addLayout(btn_row) - - # Generate initial snippet - self._refresh() - - # -- Language to code-editor language mapping ---------------------- - - _LANG_MAP: ClassVar[dict[str, str]] = { - "cURL": "text", - "Python - requests": "text", - "Python - http.client": "text", - "JavaScript - fetch": "text", - "JavaScript - axios": "text", - } - - def _refresh(self) -> None: - """Regenerate the code snippet for the selected language.""" - lang = self._lang_combo.currentText() - snippet = SnippetGenerator.generate( - lang, - method=self._method, - url=self._url, - headers=self._headers, - body=self._body, - ) - editor_lang = self._LANG_MAP.get(lang, "text") - self._code_edit.set_language(editor_lang) - self._code_edit.set_text(snippet) - self._status_label.setText("") - - def _copy_to_clipboard(self) -> None: - """Copy the current snippet text to the system clipboard.""" - clipboard: QClipboard | None = QGuiApplication.clipboard() - if clipboard is not None: - clipboard.setText(self._code_edit.toPlainText()) - self._status_label.setText("Copied!") diff --git a/src/ui/dialogs/save_request_dialog.py b/src/ui/dialogs/save_request_dialog.py new file mode 100644 index 0000000..ee613b3 --- /dev/null +++ b/src/ui/dialogs/save_request_dialog.py @@ -0,0 +1,265 @@ +"""Save Request dialog — lets the user pick a name and collection for a draft request. + +Shown when saving a request that has not yet been persisted to any +collection (i.e. ``request_id is None``). Displays a searchable +tree of existing collections and a "New Collection" button for +creating one inline. +""" + +from __future__ import annotations + +import logging +from typing import Any + +from PySide6.QtCore import Qt +from PySide6.QtWidgets import ( + QDialog, + QHBoxLayout, + QLabel, + QLineEdit, + QPushButton, + QTreeWidget, + QTreeWidgetItem, + QVBoxLayout, + QWidget, +) + +from services.collection_service import CollectionService +from ui.styling.icons import phi + +logger = logging.getLogger(__name__) + +# Qt data role used to store the collection ID on each tree item +_ROLE_COLLECTION_ID = Qt.ItemDataRole.UserRole + + +class SaveRequestDialog(QDialog): + """Modal dialog for saving a draft request into a collection. + + After the dialog is accepted, call :meth:`request_name` and + :meth:`selected_collection_id` to retrieve the user's choices. + """ + + def __init__( + self, + *, + default_name: str = "Untitled Request", + parent: QWidget | None = None, + ) -> None: + """Initialise the dialog with a name field and collection tree.""" + super().__init__(parent) + self.setWindowTitle("Save Request") + self.setMinimumWidth(420) + self.setMinimumHeight(460) + + self._collection_id: int | None = None + + root = QVBoxLayout(self) + root.setContentsMargins(20, 16, 20, 16) + root.setSpacing(12) + + # -- Request name field ---------------------------------------- + name_label = QLabel("Request name") + name_label.setObjectName("sectionLabel") + root.addWidget(name_label) + + self._name_input = QLineEdit() + self._name_input.setText(default_name) + self._name_input.selectAll() + root.addWidget(self._name_input) + + # -- "Save to" label ------------------------------------------- + save_to_label = QLabel("Save to") + save_to_label.setObjectName("sectionLabel") + root.addWidget(save_to_label) + + # -- Search / filter ------------------------------------------- + self._search_input = QLineEdit() + self._search_input.setPlaceholderText("Search for collection or folder") + search_icon = phi("magnifying-glass") + self._search_input.addAction(search_icon, QLineEdit.ActionPosition.LeadingPosition) + self._search_input.textChanged.connect(self._on_search_changed) + root.addWidget(self._search_input) + + # -- Collection tree ------------------------------------------- + self._tree = QTreeWidget() + self._tree.setObjectName("collectionTree") + self._tree.setHeaderHidden(True) + self._tree.setRootIsDecorated(True) + self._tree.setIndentation(16) + self._tree.itemClicked.connect(self._on_item_clicked) + self._tree.itemDoubleClicked.connect(self._on_item_double_clicked) + self._tree.itemExpanded.connect(self._on_item_expanded) + self._tree.itemCollapsed.connect(self._on_item_collapsed) + root.addWidget(self._tree, 1) + + # -- "New Collection" button ----------------------------------- + new_coll_btn = QPushButton("New Collection") + new_coll_btn.setIcon(phi("folder-plus")) + new_coll_btn.setObjectName("flatAccentButton") + new_coll_btn.setCursor(Qt.CursorShape.PointingHandCursor) + new_coll_btn.clicked.connect(self._on_new_collection) + root.addWidget(new_coll_btn) + + # -- Action buttons -------------------------------------------- + btn_row = QHBoxLayout() + btn_row.setSpacing(8) + btn_row.addStretch() + + cancel_btn = QPushButton("Cancel") + cancel_btn.setObjectName("dismissButton") + cancel_btn.setCursor(Qt.CursorShape.PointingHandCursor) + cancel_btn.clicked.connect(self.reject) + btn_row.addWidget(cancel_btn) + + self._save_btn = QPushButton("Save") + self._save_btn.setObjectName("primaryButton") + self._save_btn.setCursor(Qt.CursorShape.PointingHandCursor) + self._save_btn.setEnabled(False) + self._save_btn.clicked.connect(self.accept) + btn_row.addWidget(self._save_btn) + + root.addLayout(btn_row) + + # -- Populate tree --------------------------------------------- + self._load_collections() + + # ------------------------------------------------------------------ + # Public accessors + # ------------------------------------------------------------------ + def request_name(self) -> str: + """Return the user-entered request name.""" + return self._name_input.text().strip() or "Untitled Request" + + def selected_collection_id(self) -> int | None: + """Return the ID of the selected collection, or ``None``.""" + return self._collection_id + + # ------------------------------------------------------------------ + # Data loading + # ------------------------------------------------------------------ + def _load_collections(self) -> None: + """Fetch all collections and populate the tree widget.""" + tree_data = CollectionService.fetch_all() + self._tree.clear() + folder_icon = phi("folder") + self._build_tree(tree_data, parent_item=None, folder_icon=folder_icon) + + def _build_tree( + self, + node: dict[str, Any], + parent_item: QTreeWidgetItem | None, + folder_icon: Any, + ) -> None: + """Recursively build QTreeWidgetItems from the nested collection dict.""" + for _key, child in sorted(node.items(), key=lambda kv: kv[1].get("name", "")): + if child.get("type") == "folder": + item = QTreeWidgetItem() + item.setText(0, child["name"]) + item.setIcon(0, folder_icon) + item.setData(0, _ROLE_COLLECTION_ID, child["id"]) + item.setChildIndicatorPolicy(QTreeWidgetItem.ChildIndicatorPolicy.ShowIndicator) + if parent_item is not None: + parent_item.addChild(item) + else: + self._tree.addTopLevelItem(item) + children = child.get("children", {}) + if children: + self._build_tree(children, parent_item=item, folder_icon=folder_icon) + + # ------------------------------------------------------------------ + # Search / filter + # ------------------------------------------------------------------ + def _on_search_changed(self, text: str) -> None: + """Filter the tree based on search text, showing matching items and their ancestors.""" + needle = text.strip().lower() + if not needle: + self._set_all_visible(True) + self._tree.collapseAll() + return + # Hide everything, then reveal matches and their ancestors + self._set_all_visible(False) + self._filter_tree_items(needle) + + def _set_all_visible(self, visible: bool) -> None: + """Set visibility on every item in the tree.""" + iterator = _tree_item_iterator(self._tree) + for item in iterator: + item.setHidden(not visible) + + def _filter_tree_items(self, needle: str) -> None: + """Show items matching *needle* and all their ancestors.""" + iterator = _tree_item_iterator(self._tree) + for item in iterator: + if needle in item.text(0).lower(): + _reveal_with_ancestors(item) + + # ------------------------------------------------------------------ + # Slots + # ------------------------------------------------------------------ + def _on_item_expanded(self, item: QTreeWidgetItem) -> None: + """Swap to open-folder icon when a folder is expanded.""" + item.setIcon(0, phi("folder-open")) + + def _on_item_collapsed(self, item: QTreeWidgetItem) -> None: + """Restore closed-folder icon when a folder is collapsed.""" + item.setIcon(0, phi("folder")) + + def _on_item_clicked(self, item: QTreeWidgetItem, _column: int) -> None: + """Update selection state when a collection is clicked.""" + self._collection_id = item.data(0, _ROLE_COLLECTION_ID) + self._save_btn.setEnabled(self._collection_id is not None) + + def _on_item_double_clicked(self, item: QTreeWidgetItem, _column: int) -> None: + """Accept the dialog on double-click.""" + self._on_item_clicked(item, 0) + if self._collection_id is not None: + self.accept() + + def _on_new_collection(self) -> None: + """Create a new top-level collection and add it to the tree.""" + try: + new_coll = CollectionService.create_collection("New Collection") + except Exception: + logger.exception("Failed to create collection from save dialog") + return + folder_icon = phi("folder") + item = QTreeWidgetItem() + item.setText(0, new_coll.name) + item.setIcon(0, folder_icon) + item.setData(0, _ROLE_COLLECTION_ID, new_coll.id) + self._tree.addTopLevelItem(item) + self._tree.setCurrentItem(item) + self._collection_id = new_coll.id + self._save_btn.setEnabled(True) + + +# ------------------------------------------------------------------ +# Module-level helpers +# ------------------------------------------------------------------ +def _tree_item_iterator(tree: QTreeWidget) -> list[QTreeWidgetItem]: + """Return a flat list of every QTreeWidgetItem in *tree*.""" + items: list[QTreeWidgetItem] = [] + + def _walk(parent_item: QTreeWidgetItem) -> None: + items.append(parent_item) + for i in range(parent_item.childCount()): + child = parent_item.child(i) + if child is not None: + _walk(child) + + for i in range(tree.topLevelItemCount()): + top = tree.topLevelItem(i) + if top is not None: + _walk(top) + return items + + +def _reveal_with_ancestors(item: QTreeWidgetItem) -> None: + """Unhide *item* and all its ancestor items, expanding as needed.""" + item.setHidden(False) + parent = item.parent() + while parent is not None: + parent.setHidden(False) + parent.setExpanded(True) + parent = parent.parent() diff --git a/src/ui/main_window/draft_controller.py b/src/ui/main_window/draft_controller.py new file mode 100644 index 0000000..92fd5a5 --- /dev/null +++ b/src/ui/main_window/draft_controller.py @@ -0,0 +1,175 @@ +"""Draft request tab lifecycle mixin for the main window. + +Provides ``_DraftControllerMixin`` with methods to open an unsaved +("draft") request tab and to save it via the save-to-collection dialog. +Mixed into ``MainWindow``. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, cast + +from services.collection_service import CollectionService, RequestLoadDict +from ui.request.navigation.tab_manager import TabContext +from ui.request.request_editor import RequestEditorWidget +from ui.request.response_viewer import ResponseViewerWidget + +if TYPE_CHECKING: + from PySide6.QtWidgets import QPushButton, QStackedWidget, QWidget + + from ui.collections.collection_widget import CollectionWidget + from ui.request.navigation.breadcrumb_bar import BreadcrumbBar + from ui.request.navigation.request_tab_bar import RequestTabBar + +logger = logging.getLogger(__name__) + +# Default label for unsaved request tabs +_DRAFT_TAB_NAME = "Untitled Request" + + +class _DraftControllerMixin: + """Mixin that manages draft (unsaved) request tab lifecycle. + + Expects the host class to provide ``_tabs``, ``_tab_bar``, + ``_editor_stack``, ``_response_stack``, ``_breadcrumb_bar``, + ``_save_btn``, ``collection_widget``, and the signal helper + methods from ``_TabControllerMixin``. + """ + + if TYPE_CHECKING: + _tabs: dict[int, TabContext] + _tab_bar: RequestTabBar + _editor_stack: QStackedWidget + _response_stack: QStackedWidget + _breadcrumb_bar: BreadcrumbBar + _save_btn: QPushButton + request_widget: RequestEditorWidget + response_widget: ResponseViewerWidget + collection_widget: CollectionWidget + + def _on_send_request(self) -> None: ... + def _on_save_request(self) -> None: ... + def _on_save_response(self, data: dict) -> None: ... + def _sync_save_btn(self, dirty: bool) -> None: ... + def _on_tab_changed(self, index: int) -> None: ... + + # ------------------------------------------------------------------ + # Open a new draft request tab + # ------------------------------------------------------------------ + def _open_draft_request(self) -> None: + """Open a new draft request tab that is not yet persisted to the DB. + + The tab has ``request_id=None`` and is marked dirty immediately so + the Save button is enabled. Saving triggers the save-to-collection + dialog. + """ + data: RequestLoadDict = { + "name": _DRAFT_TAB_NAME, + "method": "GET", + "url": "", + } + + editor = RequestEditorWidget() + viewer = ResponseViewerWidget() + + self._editor_stack.addWidget(editor) + self._response_stack.addWidget(viewer) + + ctx = TabContext( + request_id=None, + editor=editor, + response_viewer=viewer, + ) + + self._tab_bar.blockSignals(True) + try: + idx = self._tab_bar.add_request_tab("GET", _DRAFT_TAB_NAME) + finally: + self._tab_bar.blockSignals(False) + + ctx.draft_name = _DRAFT_TAB_NAME + self._tabs[idx] = ctx + + editor.load_request(data, request_id=None) + editor.send_requested.connect(self._on_send_request) + editor.save_requested.connect(self._on_save_request) + editor.dirty_changed.connect(self._sync_save_btn) + viewer.save_response_requested.connect(self._on_save_response) + + # Mark as dirty so Save button is enabled for the new draft + editor._set_dirty(True) + + self._tab_bar.setCurrentIndex(idx) + self._on_tab_changed(idx) + + # ------------------------------------------------------------------ + # Save draft request → save-to-collection dialog + # ------------------------------------------------------------------ + def _save_draft_request(self, ctx: TabContext | None, editor: RequestEditorWidget) -> None: + """Open the save-to-collection dialog for a draft (unsaved) request. + + On accept, creates the request in the DB and upgrades the tab + from draft to a normal persisted request. + """ + from ui.dialogs.save_request_dialog import SaveRequestDialog + + # Prefer the user-chosen draft name (set via breadcrumb rename), + # then the URL text, then the default placeholder. + idx = self._tab_bar.currentIndex() + draft_ctx = self._tabs.get(idx) + draft_label = draft_ctx.draft_name if draft_ctx is not None else None + url_text = editor._url_input.text().strip() + default_name = url_text or draft_label or _DRAFT_TAB_NAME + dialog = SaveRequestDialog(default_name=default_name, parent=cast("QWidget", self)) + if dialog.exec() != dialog.DialogCode.Accepted: + return + + request_name = dialog.request_name() + collection_id = dialog.selected_collection_id() + if collection_id is None: + return + + data = editor.get_request_data() + method = data.get("method", "GET") + url = data.get("url", "") + try: + new_request = CollectionService.create_request( + collection_id, + method, + url, + request_name, + body=data.get("body"), + request_parameters=data.get("request_parameters"), + headers=data.get("headers"), + scripts=data.get("scripts"), + ) + except Exception: + logger.exception("Failed to create request in collection %s", collection_id) + return + + # Upgrade draft tab to a persisted request + editor._request_id = new_request.id + editor._set_dirty(False) + + if ctx is not None: + ctx.request_id = new_request.id + idx = self._tab_bar.currentIndex() + display_name = url if url else request_name + self._tab_bar.update_tab(idx, method=method, name=display_name, is_dirty=False) + # Refresh breadcrumb + crumbs = CollectionService.get_request_breadcrumb(new_request.id) + self._breadcrumb_bar.set_path(crumbs) + + # Add the request to the tree sidebar + self.collection_widget._tree_widget.add_request( + { + "name": new_request.name, + "url": new_request.url, + "id": new_request.id, + "method": new_request.method, + }, + collection_id, + ) + self.collection_widget.select_and_scroll_to(new_request.id, "request") + logger.info("Draft saved as request id=%s in collection=%s", new_request.id, collection_id) diff --git a/src/ui/main_window/send_pipeline.py b/src/ui/main_window/send_pipeline.py index 4a2264d..148df5a 100644 --- a/src/ui/main_window/send_pipeline.py +++ b/src/ui/main_window/send_pipeline.py @@ -75,8 +75,8 @@ def _on_send_request(self) -> None: from services.collection_service import CollectionService auth_data = editor._get_auth_data() - if ctx and ctx.request_id and (not auth_data or auth_data.get("type") in (None, "noauth")): - inherited = CollectionService.get_request_auth_chain(ctx.request_id) + if ctx and ctx.request_id and auth_data is None: + inherited = CollectionService.get_request_inherited_auth(ctx.request_id) if inherited: auth_data = inherited diff --git a/src/ui/main_window/tab_controller.py b/src/ui/main_window/tab_controller.py index dc88156..4e76c17 100644 --- a/src/ui/main_window/tab_controller.py +++ b/src/ui/main_window/tab_controller.py @@ -68,6 +68,8 @@ def _refresh_variable_map( request_id: int | None, local_overrides: dict | None = ..., ) -> None: ... + def _refresh_sidebar(self) -> None: ... + def _schedule_sidebar_snippet_refresh(self) -> None: ... # ------------------------------------------------------------------ # Open request @@ -173,6 +175,7 @@ def _create_tab( editor.send_requested.connect(self._on_send_request) editor.save_requested.connect(self._on_save_request) editor.dirty_changed.connect(self._sync_save_btn) + editor.request_changed.connect(lambda _: self._schedule_sidebar_snippet_refresh()) viewer.save_response_requested.connect(self._on_save_response) # Now switch to the tab (triggers _on_tab_changed safely) @@ -238,6 +241,11 @@ def _on_tab_changed(self, index: int) -> None: if ctx.request_id is not None: crumbs = CollectionService.get_request_breadcrumb(ctx.request_id) self._breadcrumb_bar.set_path(crumbs) + elif ctx.draft_name is not None: + # Draft tab — show editable single-segment breadcrumb + self._breadcrumb_bar.set_path( + [{"name": ctx.draft_name, "type": "request", "id": 0}] + ) else: self._breadcrumb_bar.clear() # Load saved responses @@ -255,6 +263,9 @@ def _on_tab_changed(self, index: int) -> None: self._breadcrumb_bar.clear() self._save_btn.setVisible(False) + # Refresh right sidebar for the active tab + self._refresh_sidebar() + # ------------------------------------------------------------------ # Tab close # ------------------------------------------------------------------ @@ -290,6 +301,7 @@ def _on_tab_close(self, index: int) -> None: editor.send_requested.disconnect(self._on_send_request) editor.save_requested.disconnect(self._on_save_request) editor.dirty_changed.disconnect(self._sync_save_btn) + editor.request_changed.disconnect() viewer.save_response_requested.disconnect(self._on_save_response) # Remove from stacked widgets and detach from parent hierarchy. @@ -451,6 +463,15 @@ def _on_breadcrumb_clicked(self, item_type: str, item_id: int) -> None: def _on_breadcrumb_rename(self, new_name: str) -> None: """Rename the current request/folder from the breadcrumb bar.""" + idx = self._tab_bar.currentIndex() + ctx = self._tabs.get(idx) + + # Draft tab — no DB entry yet, update tab name and context only + if ctx is not None and ctx.request_id is None and ctx.draft_name is not None: + ctx.draft_name = new_name + self._tab_bar.update_tab(idx, name=new_name) + return + seg = self._breadcrumb_bar.last_segment_info if seg is None: return diff --git a/src/ui/main_window/variable_controller.py b/src/ui/main_window/variable_controller.py index 154bc5b..9354f7b 100644 --- a/src/ui/main_window/variable_controller.py +++ b/src/ui/main_window/variable_controller.py @@ -1,8 +1,8 @@ -"""Variable management mixin for the main window. +"""Variable and sidebar management mixin for the main window. Provides ``_VariableControllerMixin`` with environment variable -refresh, update, local override, and unresolved-variable callbacks. -Mixed into ``MainWindow``. +refresh, update, local override, unresolved-variable callbacks, and +right-sidebar refresh logic. Mixed into ``MainWindow``. """ from __future__ import annotations @@ -10,6 +10,8 @@ import logging from typing import TYPE_CHECKING +from PySide6.QtCore import QTimer + from services.environment_service import EnvironmentService if TYPE_CHECKING: @@ -17,20 +19,24 @@ from ui.environments.environment_selector import EnvironmentSelector from ui.request.navigation.tab_manager import TabContext from ui.request.request_editor import RequestEditorWidget + from ui.sidebar import RightSidebar logger = logging.getLogger(__name__) class _VariableControllerMixin: - """Mixin that manages variable maps and popup callbacks. + """Mixin that manages variable maps, popup callbacks, and sidebar. Expects the host class to provide ``_tabs``, ``_env_selector``, - ``_tab_bar``, ``request_widget``, and ``_current_tab_context()``. + ``_tab_bar``, ``request_widget``, ``_right_sidebar``, and + ``_current_tab_context()``. """ # -- Host-class interface (declared for mypy) ----------------------- _env_selector: EnvironmentSelector _tabs: dict[int, TabContext] + _right_sidebar: RightSidebar + _sidebar_debounce: QTimer def _current_tab_context(self) -> TabContext | None: ... @@ -70,6 +76,7 @@ def _on_environment_changed(self, _env_id: object) -> None: for ctx in self._tabs.values(): if ctx.tab_type != "folder": self._refresh_variable_map(ctx.editor, ctx.request_id, ctx.local_overrides) + self._refresh_sidebar() def _on_variable_updated( self, @@ -174,3 +181,126 @@ def _on_add_unresolved_variable( tab_ctx.request_id, tab_ctx.local_overrides, ) + self._refresh_sidebar() + + # ------------------------------------------------------------------ + # Right-sidebar helpers + # ------------------------------------------------------------------ + def _refresh_sidebar(self) -> None: + """Update the right sidebar panels for the active tab.""" + ctx = self._current_tab_context() + env_id = self._env_selector.current_environment_id() + has_env = env_id is not None + + if ctx is None: + self._right_sidebar.clear() + return + + if ctx.tab_type == "folder": + variables = EnvironmentService.build_combined_variable_detail_map(env_id, None) + # Merge folder-level variables from the collection chain + if ctx.collection_id is not None: + from database.models.collections.collection_query_repository import ( + get_collection_variable_chain_detailed, + ) + + for key, (value, coll_id) in get_collection_variable_chain_detailed( + ctx.collection_id + ).items(): + if key not in variables: + variables[key] = { + "value": value, + "source": "collection", + "source_id": coll_id, + } + self._right_sidebar.show_folder_panels(variables, has_environment=has_env) + else: + variables = EnvironmentService.build_combined_variable_detail_map( + env_id, ctx.request_id + ) + # Layer per-request overrides on top + if ctx.local_overrides: + for key, override in ctx.local_overrides.items(): + variables[key] = { + "value": override["value"], + "source": override["original_source"], + "source_id": override["original_source_id"], + "is_local": True, + } + editor = ctx.editor + data = editor.get_request_data() + # Resolve {{variable}} placeholders for the snippet. + flat_vars = {k: v["value"] for k, v in variables.items()} + sub = EnvironmentService.substitute + # Resolve inherited auth for sidebar / snippet display + auth = data.get("auth") + if auth is None and ctx.request_id: + from services.collection_service import CollectionService + + auth = CollectionService.get_request_inherited_auth(ctx.request_id) + self._right_sidebar.show_request_panels( + variables, + local_overrides=ctx.local_overrides, + has_environment=has_env, + method=editor._method_combo.currentText(), + url=sub(editor._url_input.text().strip(), flat_vars), + headers=sub(editor.get_headers_text() or "", flat_vars) or None, + body=sub(data.get("body") or "", flat_vars) or None, + auth=auth, + ) + + def _schedule_sidebar_snippet_refresh(self) -> None: + """Debounce snippet refresh (300 ms) on request editor changes.""" + self._sidebar_debounce.start(300) + + def _refresh_sidebar_snippet(self) -> None: + """Regenerate only the snippet panel for the active request tab.""" + ctx = self._current_tab_context() + if ctx is None or ctx.tab_type == "folder": + return + editor = ctx.editor + data = editor.get_request_data() + # Resolve {{variable}} placeholders for the snippet. + env_id = self._env_selector.current_environment_id() + variables = EnvironmentService.build_combined_variable_detail_map(env_id, ctx.request_id) + if ctx.local_overrides: + for key, override in ctx.local_overrides.items(): + variables[key] = { + "value": override["value"], + "source": override["original_source"], + "source_id": override["original_source_id"], + "is_local": True, + } + flat_vars = {k: v["value"] for k, v in variables.items()} + sub = EnvironmentService.substitute + self._right_sidebar.snippet_panel.update_request( + method=editor._method_combo.currentText(), + url=sub(editor._url_input.text().strip(), flat_vars), + headers=sub(editor.get_headers_text() or "", flat_vars) or None, + body=sub(data.get("body") or "", flat_vars) or None, + auth=self._resolve_snippet_auth(data.get("auth"), ctx.request_id), + ) + + def _resolve_snippet_auth(self, auth: dict | None, request_id: int | None) -> dict | None: + """Return the effective auth for snippet generation. + + If the request uses "Inherit auth from parent" (``auth is None``) + and has a saved request_id, resolve the inherited auth from the + collection chain. + """ + if auth is None and request_id: + from services.collection_service import CollectionService + + return CollectionService.get_request_inherited_auth(request_id) + return auth + + def _toggle_right_sidebar(self) -> None: + """Toggle the right sidebar panel open or closed.""" + if self._right_sidebar.panel_open: + self._right_sidebar._close_panel() + else: + self._right_sidebar.open_panel("variables") + + def _on_snippet_shortcut(self) -> None: + """Open the sidebar with the snippet panel visible.""" + self._right_sidebar.open_panel("snippet") diff --git a/src/ui/main_window/window.py b/src/ui/main_window/window.py index d0b6aac..6cd218e 100644 --- a/src/ui/main_window/window.py +++ b/src/ui/main_window/window.py @@ -5,7 +5,7 @@ import logging from typing import TYPE_CHECKING, Any -from PySide6.QtCore import QSize, Qt, QThread +from PySide6.QtCore import QSize, Qt, QThread, QTimer from PySide6.QtGui import QAction, QCloseEvent, QCursor, QGuiApplication, QKeySequence if TYPE_CHECKING: @@ -28,6 +28,7 @@ from ui.collections.collection_widget import CollectionWidget from ui.environments.environment_selector import EnvironmentSelector from ui.loading_screen import LoadingScreen +from ui.main_window.draft_controller import _DraftControllerMixin from ui.main_window.send_pipeline import _SendPipelineMixin from ui.main_window.tab_controller import _TabControllerMixin from ui.main_window.variable_controller import _VariableControllerMixin @@ -38,6 +39,7 @@ from ui.request.navigation.tab_manager import TabContext from ui.request.request_editor import RequestEditorWidget from ui.request.response_viewer import ResponseViewerWidget +from ui.sidebar import RightSidebar from ui.styling.icons import phi from ui.styling.theme_manager import ThemeManager @@ -47,13 +49,14 @@ class MainWindow( _SendPipelineMixin, _VariableControllerMixin, + _DraftControllerMixin, _TabControllerMixin, QMainWindow, ): """Top-level application window. - Sets up the menu bar, toolbar, and the three-pane layout - (collection sidebar | request editor | response viewer). + Sets up the menu bar, toolbar, and four-pane layout + (collection sidebar | request editor | response viewer | right sidebar rail). """ def __init__(self, theme_manager: ThemeManager | None = None) -> None: @@ -80,11 +83,22 @@ def __init__(self, theme_manager: ThemeManager | None = None) -> None: self.collection_widget = CollectionWidget(self) + # Right sidebar (created before _setup_ui so layout can embed it) + self._right_sidebar = RightSidebar() + + # Debounce timer for live snippet updates in the sidebar + self._sidebar_debounce = QTimer(self) + self._sidebar_debounce.setSingleShot(True) + self._sidebar_debounce.timeout.connect(self._refresh_sidebar_snippet) + self._setup_ui() # Wire sidebar -> editor self.collection_widget.item_action_triggered.connect(self._on_item_action) + # Wire draft request + self.collection_widget.draft_request_requested.connect(self._open_draft_request) + # Wire save -> save pipeline self.request_widget.save_requested.connect(self._on_save_request) self.request_widget.dirty_changed.connect(self._sync_save_btn) @@ -122,24 +136,23 @@ def __init__(self, theme_manager: ThemeManager | None = None) -> None: # Wire tree rename -> update open tabs self.collection_widget.item_name_changed.connect(self._on_item_name_changed) + # Start the collection fetch *after* all signals are connected so + # a fast-completing fetch cannot emit load_finished before we listen. + self.collection_widget._start_fetch() + # ---- Move to the screen that contains the mouse -------------- self._move_to_mouse_screen() def _move_to_mouse_screen(self) -> None: """Center the window on the monitor that the cursor is on.""" - # 1. Find the screen that the cursor is currently on - cursor_pos = QCursor.pos() # global screen coordinates + cursor_pos = QCursor.pos() screen = QGuiApplication.screenAt(cursor_pos) - - # 2. If we found a screen, move the window so it is centered there if screen is not None: - screen_geom = screen.availableGeometry() # skip taskbars, docks - win_geom = self.frameGeometry() # includes frame + screen_geom = screen.availableGeometry() + win_geom = self.frameGeometry() win_geom.moveCenter(screen_geom.center()) self.move(win_geom.topLeft()) - # 3. If screen is None (rare), just leave the window where Qt chose - # ------------------------------------------------------------------ # Menu creation # ------------------------------------------------------------------ @@ -150,6 +163,14 @@ def _create_menus(self) -> None: # File menu file_menu = menubar.addMenu("&File") + new_req_act = QAction("&New Request", self) + new_req_act.setIcon(phi("plus")) + new_req_act.setShortcut(QKeySequence("Ctrl+N")) + new_req_act.triggered.connect(self._open_draft_request) + file_menu.addAction(new_req_act) + + file_menu.addSeparator() + import_act = QAction("&Import...", self) import_act.setIcon(phi("download-simple")) import_act.setShortcut(QKeySequence("Ctrl+I")) @@ -167,7 +188,7 @@ def _create_menus(self) -> None: snippet_act = QAction("Generate Code &Snippet\u2026", self) snippet_act.setIcon(phi("code")) snippet_act.setShortcut(QKeySequence("Ctrl+Shift+C")) - snippet_act.triggered.connect(self._on_code_snippet) + snippet_act.triggered.connect(self._on_snippet_shortcut) file_menu.addAction(snippet_act) file_menu.addSeparator() @@ -346,7 +367,7 @@ def _setup_ui(self) -> None: self._main_stack.addWidget(self._loading_screen) self._loading_screen.start_animation() - # 3. Main splitter: left (nav) + right (request+response) + # 3. Main splitter: left (nav) + right (request+response+sidebar) central = QWidget() main_layout = QHBoxLayout(central) main_layout.setContentsMargins(9, 0, 0, 0) @@ -359,12 +380,21 @@ def _setup_ui(self) -> None: # --- Left navigation pane --- self._main_splitter.addWidget(self.collection_widget) - # --- Right side: vertical splitter (request + response) --- + # --- Centre: vertical splitter (request + response) --- request_area = self._build_request_area() self._right_splitter = QSplitter(Qt.Orientation.Vertical) - self._main_splitter.addWidget(self._right_splitter) - self._main_splitter.setStretchFactor(1, 3) # right side takes 3x the space + + # --- Content area: centre panes + right sidebar rail --- + self._content_splitter = QSplitter(Qt.Orientation.Horizontal) + self._content_splitter.setHandleWidth(4) + self._main_splitter.addWidget(self._content_splitter) + self._main_splitter.setStretchFactor(1, 3) + + self._content_splitter.addWidget(self._right_splitter) + self._right_sidebar.install_in_splitter(self._content_splitter) + self._content_splitter.setStretchFactor(0, 1) + self._content_splitter.setCollapsible(0, False) # --- Request editor area --- self._right_splitter.addWidget(request_area) @@ -430,27 +460,6 @@ def _toggle_layout_orientation(self) -> None: # ------------------------------------------------------------------ # Dialogs # ------------------------------------------------------------------ - def _on_code_snippet(self) -> None: - """Open the code snippet dialog for the current request.""" - from ui.dialogs.code_snippet_dialog import CodeSnippetDialog - - ctx = self._current_tab_context() - if ctx is not None and ctx.tab_type == "folder": - return - editor = ctx.editor if ctx else self.request_widget - method = editor._method_combo.currentText() - url = editor._url_input.text().strip() - headers = editor.get_headers_text() - body = editor.get_request_data().get("body") or None - dialog = CodeSnippetDialog( - method=method, - url=url, - headers=headers, - body=body, - parent=self, - ) - dialog.exec() - def _on_settings(self) -> None: """Open the settings dialog.""" from ui.dialogs.settings_dialog import SettingsDialog @@ -484,6 +493,8 @@ def _on_save_request(self) -> None: """Save the current request editor contents to the database. For folder tabs, triggers an immediate auto-save instead. + For draft tabs (``request_id is None``), opens the + save-to-collection dialog. """ ctx = self._current_tab_context() @@ -497,8 +508,15 @@ def _on_save_request(self) -> None: editor = ctx.editor if ctx else self.request_widget request_id = editor.request_id + + # Draft request -- open save-to-collection dialog + # Only when an actual tab exists (ctx is not None); a bare editor + # with no tab is not saveable. if request_id is None: + if ctx is not None: + self._save_draft_request(ctx, editor) return + if not editor.is_dirty: return diff --git a/src/ui/request/auth/__init__.py b/src/ui/request/auth/__init__.py new file mode 100644 index 0000000..86574bf --- /dev/null +++ b/src/ui/request/auth/__init__.py @@ -0,0 +1,33 @@ +"""Auth sub-package — shared auth UI, serialisation, and type definitions. + +Re-exports the public API used by :class:`RequestEditorWidget`, +:class:`FolderEditorWidget`, and other consumers. +""" + +from __future__ import annotations + +from ui.request.auth.auth_field_specs import AUTH_FIELD_SPECS +from ui.request.auth.auth_mixin import _AuthMixin +from ui.request.auth.auth_pages import ( + AUTH_FIELD_ORDER, + AUTH_KEY_TO_DISPLAY, + AUTH_PAGE_INDEX, + AUTH_TYPE_DESCRIPTIONS, + AUTH_TYPE_KEYS, + AUTH_TYPE_LABELS, + AUTH_TYPES, +) +from ui.request.auth.oauth2_page import OAuth2Page + +__all__ = [ + "AUTH_FIELD_ORDER", + "AUTH_FIELD_SPECS", + "AUTH_KEY_TO_DISPLAY", + "AUTH_PAGE_INDEX", + "AUTH_TYPES", + "AUTH_TYPE_DESCRIPTIONS", + "AUTH_TYPE_KEYS", + "AUTH_TYPE_LABELS", + "OAuth2Page", + "_AuthMixin", +] diff --git a/src/ui/request/auth/auth_field_specs.py b/src/ui/request/auth/auth_field_specs.py new file mode 100644 index 0000000..49c686d --- /dev/null +++ b/src/ui/request/auth/auth_field_specs.py @@ -0,0 +1,293 @@ +"""Per-type field specifications for all supported auth types. + +Each key maps to a tuple of :class:`FieldSpec` descriptors that drive +the data-driven page builder in :mod:`auth_pages`. +""" + +from __future__ import annotations + +from ui.request.auth.auth_pages import FieldSpec + +AUTH_FIELD_SPECS: dict[str, tuple[FieldSpec, ...]] = { + "bearer": (FieldSpec("token", "Token", placeholder="Enter bearer token"),), + "basic": ( + FieldSpec("username", "Username", placeholder="Username"), + FieldSpec("password", "Password", kind="password", placeholder="Password"), + ), + "apikey": ( + FieldSpec("key", "Key", placeholder="Header or query parameter name"), + FieldSpec("value", "Value", placeholder="API key value"), + FieldSpec( + "in", + "Add to", + kind="combo", + options=("Header", "Query Params"), + width=140, + combo_map={"header": "Header", "query": "Query Params"}, + default="header", + ), + ), + "digest": ( + FieldSpec("username", "Username", placeholder="Username"), + FieldSpec("password", "Password", kind="password", placeholder="Password"), + FieldSpec("realm", "Realm", placeholder="Realm", advanced=True), + FieldSpec("nonce", "Nonce", placeholder="Server nonce", advanced=True), + FieldSpec( + "algorithm", + "Algorithm", + kind="combo", + options=( + "MD5", + "MD5-sess", + "SHA-256", + "SHA-256-sess", + "SHA-512-256", + "SHA-512-256-sess", + ), + default="MD5", + advanced=True, + ), + FieldSpec( + "qop", + "Quality of Protection", + placeholder="auth or auth-int", + advanced=True, + ), + FieldSpec("nonceCount", "Nonce Count", placeholder="00000001", advanced=True), + FieldSpec("clientNonce", "Client Nonce", placeholder="Client nonce", advanced=True), + FieldSpec("opaque", "Opaque", placeholder="Opaque string from server", advanced=True), + ), + "oauth1": ( + FieldSpec( + "signatureMethod", + "Signature Method", + kind="combo", + options=("HMAC-SHA1", "HMAC-SHA256", "PLAINTEXT", "RSA-SHA1"), + default="HMAC-SHA1", + ), + FieldSpec("consumerKey", "Consumer Key", placeholder="Consumer key"), + FieldSpec( + "consumerSecret", + "Consumer Secret", + kind="password", + placeholder="Consumer secret", + ), + FieldSpec("token", "Access Token", placeholder="Access token"), + FieldSpec("tokenSecret", "Token Secret", kind="password", placeholder="Token secret"), + FieldSpec( + "addParamsToHeader", + "Add auth data to", + kind="combo", + options=("Request Headers", "Request Body", "Request URL"), + combo_map={ + "true": "Request Headers", + "false": "Request URL", + "body": "Request Body", + }, + default="true", + save_as_bool=False, + ), + FieldSpec( + "callbackUrl", + "Callback URL", + placeholder="https://example.com/callback", + advanced=True, + ), + FieldSpec("verifier", "Verifier", placeholder="Verifier (optional)", advanced=True), + FieldSpec("timestamp", "Timestamp", placeholder="Auto-generated if empty", advanced=True), + FieldSpec("nonce", "Nonce", placeholder="Auto-generated if empty", advanced=True), + FieldSpec("version", "Version", placeholder="1.0", default="1.0", advanced=True), + FieldSpec("realm", "Realm", placeholder="Realm (optional)", advanced=True), + FieldSpec( + "includeBodyHash", + "Include body hash", + kind="checkbox", + default="false", + save_as_bool=True, + advanced=True, + ), + FieldSpec( + "addEmptyParamsToSign", + "Add empty parameters to signature", + kind="checkbox", + default="false", + save_as_bool=True, + advanced=True, + ), + ), + "oauth2": (), + "hawk": ( + FieldSpec("authId", "Hawk Auth ID", placeholder="Auth ID"), + FieldSpec("authKey", "Hawk Auth Key", kind="password", placeholder="Auth key"), + FieldSpec( + "algorithm", + "Algorithm", + kind="combo", + options=("sha256", "sha1"), + default="sha256", + ), + FieldSpec("user", "User", placeholder="Username", advanced=True), + FieldSpec("nonce", "Nonce", placeholder="Nonce", advanced=True), + FieldSpec( + "extraData", + "ext", + placeholder="e.g. some-app-extra-data", + advanced=True, + ), + FieldSpec("appId", "app", placeholder="Application ID", advanced=True), + FieldSpec("delegation", "dlg", placeholder="e.g. delegated-by", advanced=True), + FieldSpec("timestamp", "Timestamp", placeholder="Timestamp", advanced=True), + FieldSpec( + "includePayloadHash", + "Include payload hash", + kind="checkbox", + default="false", + save_as_bool=True, + advanced=True, + ), + ), + "awsv4": ( + FieldSpec("accessKey", "Access Key", placeholder="AWS access key ID"), + FieldSpec("secretKey", "Secret Key", kind="password", placeholder="AWS secret key"), + FieldSpec("region", "AWS Region", placeholder="us-east-1", default="us-east-1"), + FieldSpec("service", "Service Name", placeholder="e.g. s3, execute-api"), + FieldSpec( + "sessionToken", + "Session Token", + placeholder="Optional session token", + advanced=True, + ), + FieldSpec( + "addAuthDataTo", + "Add auth data to", + kind="combo", + options=("Header", "Query Params"), + combo_map={"header": "Header", "queryParams": "Query Params"}, + default="header", + advanced=True, + ), + ), + "jwt": ( + FieldSpec( + "algorithm", + "Algorithm", + kind="combo", + options=( + "HS256", + "HS384", + "HS512", + "RS256", + "RS384", + "RS512", + "ES256", + "ES384", + "ES512", + "PS256", + "PS384", + "PS512", + ), + default="HS256", + ), + FieldSpec("secret", "Secret", kind="password", placeholder="Secret key (HMAC)"), + FieldSpec( + "privateKey", + "Private Key", + kind="textarea", + placeholder="RSA/EC private key in PEM format", + ), + FieldSpec("payload", "Payload", kind="textarea", placeholder='{"sub": "1234567890"}'), + FieldSpec( + "headers", + "Custom Headers", + kind="textarea", + placeholder='{"kid": "my-key-id"}', + ), + FieldSpec( + "isSecretBase64Encoded", + "Secret is Base64 encoded", + kind="combo", + options=("No", "Yes"), + combo_map={"false": "No", "true": "Yes"}, + default="false", + save_as_bool=True, + ), + FieldSpec( + "addTokenTo", + "Add token to", + kind="combo", + options=("Header", "Query Params"), + combo_map={"header": "Header", "queryParams": "Query Params"}, + default="header", + ), + FieldSpec("headerPrefix", "Header Prefix", placeholder="Bearer", default="Bearer"), + FieldSpec("queryParamKey", "Query Param Key", placeholder="token", default="token"), + ), + "asap": ( + FieldSpec( + "algorithm", + "Algorithm", + kind="combo", + options=("RS256", "PS256"), + default="RS256", + ), + FieldSpec("issuer", "Issuer", placeholder="Issuer"), + FieldSpec("audience", "Audience", placeholder="Audience"), + FieldSpec("kid", "Key ID", placeholder="Key ID"), + FieldSpec( + "privateKey", + "Private Key", + kind="textarea", + placeholder="-----BEGIN PRIVATE KEY-----\n" + "PASTE YOUR PRIVATE KEY\nHERE...\n" + "-----END PRIVATE KEY-----", + ), + FieldSpec("subject", "Subject", placeholder="Subject", advanced=True), + FieldSpec( + "claims", + "Additional claims", + kind="textarea", + placeholder="{}", + advanced=True, + ), + FieldSpec( + "expiresIn", + "Expiry", + placeholder="e.g. 3600, default: 1h", + default="3600", + advanced=True, + ), + ), + "ntlm": ( + FieldSpec("username", "Username", placeholder="Username"), + FieldSpec("password", "Password", kind="password", placeholder="Password"), + FieldSpec("domain", "Domain", placeholder="e.g. example.com", advanced=True), + FieldSpec("workstation", "Workstation", placeholder="e.g. John-PC", advanced=True), + ), + "edgegrid": ( + FieldSpec("accessToken", "Access Token", placeholder="Access token"), + FieldSpec("clientToken", "Client Token", placeholder="Client token"), + FieldSpec( + "clientSecret", + "Client Secret", + kind="password", + placeholder="Client secret", + ), + FieldSpec("nonce", "Nonce", placeholder="Nonce", advanced=True), + FieldSpec("timestamp", "Timestamp", placeholder="Timestamp", advanced=True), + FieldSpec("baseURL", "Base URL", placeholder="Base URL", advanced=True), + FieldSpec( + "headersToSign", + "Headers to Sign", + placeholder="Headers to Sign", + advanced=True, + ), + FieldSpec( + "maxBody", + "Max Body Size", + placeholder="Max Body Size", + default="131072", + advanced=True, + suffix="bytes", + ), + ), +} diff --git a/src/ui/request/auth/auth_mixin.py b/src/ui/request/auth/auth_mixin.py new file mode 100644 index 0000000..148fd6d --- /dev/null +++ b/src/ui/request/auth/auth_mixin.py @@ -0,0 +1,314 @@ +"""Shared auth-tab mixin for request and folder editors. + +Provides :class:`_AuthMixin` containing auth UI construction +(type selector, stacked field pages for all supported auth types), +inherit-preview logic, and load / save / clear helpers. + +Mixed into both :class:`RequestEditorWidget` and +:class:`FolderEditorWidget`. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any, cast + +from PySide6.QtCore import Qt, QThread +from PySide6.QtWidgets import ( + QCheckBox, + QComboBox, + QFrame, + QHBoxLayout, + QLabel, + QLineEdit, + QMessageBox, + QStackedWidget, + QTextEdit, + QVBoxLayout, + QWidget, +) + +from ui.request.auth.auth_field_specs import AUTH_FIELD_SPECS +from ui.request.auth.auth_pages import ( + AUTH_FIELD_ORDER, + AUTH_KEY_TO_DISPLAY, + AUTH_PAGE_INDEX, + AUTH_TYPE_DESCRIPTIONS, + AUTH_TYPE_KEYS, + AUTH_TYPE_LABELS, + AUTH_TYPES, + build_fields_page, + build_inherit_page, + build_noauth_page, +) +from ui.request.auth.auth_serializer import get_auth_fields, load_auth_fields +from ui.request.auth.oauth2_page import OAuth2Page + +if TYPE_CHECKING: + from PySide6.QtCore import QTimer + + from ui.widgets.variable_line_edit import VariableLineEdit + +logger = logging.getLogger(__name__) + + +class _AuthMixin: + """Mixin that adds auth tab building and auth data helpers. + + Expects the host class to provide :meth:`_on_field_changed` and + a ``_loading`` flag. Works with both :class:`RequestEditorWidget` + (which has ``_request_id``) and :class:`FolderEditorWidget` + (which has ``_collection_id``). + """ + + # -- Host-class interface (declared for type checkers) -------------- + _loading: bool + _debounce_timer: QTimer + + def _on_field_changed(self) -> None: ... + + # -- UI construction (called from host __init__) -------------------- + + def _build_auth_tab(self, auth_layout: QVBoxLayout) -> None: + """Construct the auth tab with Postman-style two-column layout. + + Left column: auth type selector + description text. + Right column: stacked field pages for the selected auth type. + """ + columns = QHBoxLayout() + columns.setSpacing(0) + columns.setContentsMargins(0, 0, 0, 0) + + # -- Left column: type picker + description ----------------------- + left = QWidget() + left.setFixedWidth(260) + left_layout = QVBoxLayout(left) + left_layout.setContentsMargins(0, 0, 16, 0) + + type_label = QLabel("Auth Type") + type_label.setObjectName("sectionLabel") + left_layout.addWidget(type_label) + + self._auth_type_combo = QComboBox() + self._auth_type_combo.addItems(list(AUTH_TYPES)) + self._auth_type_combo.currentTextChanged.connect(self._on_auth_type_changed) + left_layout.addWidget(self._auth_type_combo) + + left_layout.addSpacing(12) + + self._auth_description_label = QLabel() + self._auth_description_label.setObjectName("mutedLabel") + self._auth_description_label.setWordWrap(True) + self._auth_description_label.setAlignment( + Qt.AlignmentFlag.AlignLeft | Qt.AlignmentFlag.AlignTop + ) + self._auth_description_label.setText( + AUTH_TYPE_DESCRIPTIONS.get("Inherit auth from parent", "") + ) + left_layout.addWidget(self._auth_description_label) + + # Inherit preview sits below the description + self._inherit_preview_label = QLabel() + self._inherit_preview_label.setObjectName("sectionLabel") + self._inherit_preview_label.setWordWrap(True) + left_layout.addWidget(self._inherit_preview_label) + + left_layout.addStretch() + columns.addWidget(left) + + # -- Vertical separator ------------------------------------------- + sep = QFrame() + sep.setFrameShape(QFrame.Shape.VLine) + sep.setFrameShadow(QFrame.Shadow.Sunken) + columns.addWidget(sep) + + # -- Right column: stacked field pages ---------------------------- + self._auth_fields_stack = QStackedWidget() + self._auth_widget_map: dict[str, dict[str, QWidget]] = {} + + # 1. Inherit page (index 0) — just a preview label + inherit_page, _ = build_inherit_page() + self._auth_fields_stack.addWidget(inherit_page) + + # 2. No Auth page (index 1) — empty placeholder + self._auth_fields_stack.addWidget(build_noauth_page()) + + # 3. Field-based pages (bearer, basic, apikey, digest, ...) + # OAuth 2.0 gets a custom page instead of FieldSpec. + self._oauth2_page: OAuth2Page | None = None + for auth_key in AUTH_FIELD_ORDER: + if auth_key == "oauth2": + oauth2_page = OAuth2Page(self._on_field_changed) + oauth2_page.get_token_requested.connect(self._on_get_oauth2_token) + self._oauth2_page = oauth2_page + self._auth_fields_stack.addWidget(oauth2_page) + self._auth_widget_map[auth_key] = {} + else: + specs = AUTH_FIELD_SPECS.get(auth_key, ()) + page, widgets = build_fields_page(specs, self._on_field_changed) + self._auth_fields_stack.addWidget(page) + self._auth_widget_map[auth_key] = widgets + + # Backward-compat attributes used by existing tests + bw = self._auth_widget_map.get("bearer", {}) + self._bearer_token_input = cast("VariableLineEdit", bw.get("token")) + baw = self._auth_widget_map.get("basic", {}) + self._basic_username_input = cast("VariableLineEdit", baw.get("username")) + self._basic_password_input = cast("VariableLineEdit", baw.get("password")) + akw = self._auth_widget_map.get("apikey", {}) + self._apikey_key_input = cast("VariableLineEdit", akw.get("key")) + self._apikey_value_input = cast("VariableLineEdit", akw.get("value")) + self._apikey_add_to_combo = cast(QComboBox, akw.get("in")) + + # OAuth 2.0 worker state + self._oauth2_thread: QThread | None = None + + columns.addWidget(self._auth_fields_stack, 1) + auth_layout.addLayout(columns, 1) + + # -- Auth type switching ------------------------------------------- + + def _on_auth_type_changed(self, auth_type: str) -> None: + """Switch the stacked page, update description, and track changes.""" + idx = AUTH_PAGE_INDEX.get(auth_type, 0) + self._auth_fields_stack.setCurrentIndex(idx) + self._auth_description_label.setText(AUTH_TYPE_DESCRIPTIONS.get(auth_type, "")) + is_inherit = auth_type == "Inherit auth from parent" + self._inherit_preview_label.setVisible(is_inherit) + if is_inherit: + self._update_inherit_preview() + self._on_field_changed() + + # -- Inherit preview ----------------------------------------------- + + def _update_inherit_preview(self) -> None: + """Refresh the inherit page label with the resolved parent auth.""" + from services.collection_service import CollectionService + + request_id = getattr(self, "_request_id", None) + collection_id = getattr(self, "_collection_id", None) + if request_id: + resolved = CollectionService.get_request_inherited_auth(request_id) + elif collection_id: + resolved = CollectionService.get_collection_inherited_auth(collection_id) + else: + self._inherit_preview_label.setText("No parent auth configured.") + return + self._set_inherit_preview_from_auth(resolved) + + def _set_inherit_preview_from_auth(self, auth: dict[str, Any] | None) -> None: + """Set the inherit preview label from a resolved auth dict.""" + if not auth: + self._inherit_preview_label.setText("No parent auth configured.") + return + auth_type = auth.get("type", "") + label = AUTH_TYPE_LABELS.get(auth_type, auth_type) + self._inherit_preview_label.setText(f"Using {label} from parent.") + + # -- OAuth 2.0 token flow ------------------------------------------ + + def _on_get_oauth2_token(self) -> None: + """Launch the OAuth 2.0 token worker on a background thread.""" + if self._oauth2_page is None: + return + + config = self._oauth2_page.get_config() + if not config: + return + + from ui.request.http_worker import OAuth2TokenWorker + + worker = OAuth2TokenWorker() + worker.set_config(config) + + thread = QThread() + worker.moveToThread(thread) + thread.started.connect(worker.run) + worker.finished.connect(self._on_oauth2_token_received) + worker.error.connect(self._on_oauth2_token_error) + worker.finished.connect(thread.quit) + worker.error.connect(thread.quit) + thread.finished.connect(thread.deleteLater) + thread.finished.connect(worker.deleteLater) + + self._oauth2_thread = thread + thread.start() + + def _on_oauth2_token_received(self, data: dict) -> None: + """Store the obtained token in the OAuth 2.0 page.""" + if self._oauth2_page is None: + return + token = data.get("access_token", "") + name = self._oauth2_page.get_config().get("tokenName", "") + if token: + self._oauth2_page.set_token(token, str(name)) + self._on_field_changed() + + def _on_oauth2_token_error(self, msg: str) -> None: + """Show an error dialog when the token flow fails.""" + logger.error("OAuth 2.0 token error: %s", msg) + parent = self if isinstance(self, QWidget) else None + QMessageBox.warning(parent, "OAuth 2.0 Error", msg) + + # -- Load / save / clear ------------------------------------------- + + def _load_auth(self, auth: dict | None) -> None: + """Populate auth fields from a Postman-format auth dict. + + ``None`` or ``{}`` maps to *Inherit auth from parent*. + ``{"type": "noauth"}`` maps to *No Auth*. + """ + if not auth: + self._auth_type_combo.setCurrentText("Inherit auth from parent") + return + + auth_type = auth.get("type", "inherit") + display = AUTH_KEY_TO_DISPLAY.get(auth_type, "Inherit auth from parent") + self._auth_type_combo.setCurrentText(display) + + entries = auth.get(auth_type, []) + if auth_type == "oauth2" and self._oauth2_page is not None: + self._oauth2_page.load(entries) + else: + widgets = self._auth_widget_map.get(auth_type, {}) + if entries and widgets: + load_auth_fields(auth_type, widgets, entries) + + def _get_auth_data(self) -> dict | None: + """Build the auth configuration dict from the current UI state. + + Returns ``None`` for *Inherit auth from parent* (stored as + ``auth = None`` in the database). + """ + display_name = self._auth_type_combo.currentText() + if display_name == "Inherit auth from parent": + return None + if display_name == "No Auth": + return {"type": "noauth"} + + auth_key = AUTH_TYPE_KEYS.get(display_name) + if not auth_key: + return None + + if auth_key == "oauth2" and self._oauth2_page is not None: + entries = self._oauth2_page.get_entries() + else: + widgets = self._auth_widget_map.get(auth_key, {}) + entries = get_auth_fields(auth_key, widgets) + return {"type": auth_key, auth_key: entries} + + def _clear_auth(self) -> None: + """Reset the auth combo and all field widgets to defaults.""" + self._auth_type_combo.setCurrentText("Inherit auth from parent") + for widgets in self._auth_widget_map.values(): + for widget in widgets.values(): + if isinstance(widget, QLineEdit): + widget.clear() + elif isinstance(widget, QComboBox): + widget.setCurrentIndex(0) + elif isinstance(widget, QTextEdit): + widget.clear() + elif isinstance(widget, QCheckBox): + widget.setChecked(False) + if self._oauth2_page is not None: + self._oauth2_page.clear() diff --git a/src/ui/request/auth/auth_pages.py b/src/ui/request/auth/auth_pages.py new file mode 100644 index 0000000..c437e3a --- /dev/null +++ b/src/ui/request/auth/auth_pages.py @@ -0,0 +1,365 @@ +"""Auth-type constants, field specifications, and data-driven page builder. + +Defines the ordered list of supported auth types, human-readable labels, +stacked-widget page indices, per-type :class:`FieldSpec` descriptors, and +a generic :func:`build_fields_page` that constructs a Qt form page from a +sequence of field specs. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from PySide6.QtCore import Qt +from PySide6.QtWidgets import ( + QCheckBox, + QComboBox, + QFormLayout, + QHBoxLayout, + QLabel, + QLineEdit, + QScrollArea, + QTextEdit, + QToolButton, + QVBoxLayout, + QWidget, +) + +from ui.widgets.variable_line_edit import VariableLineEdit + +if TYPE_CHECKING: + from collections.abc import Callable + +# --------------------------------------------------------------------------- +# Auth-type display names (stacked-widget page order) +# --------------------------------------------------------------------------- + +AUTH_TYPES: tuple[str, ...] = ( + "Inherit auth from parent", + "No Auth", + "Bearer Token", + "Basic Auth", + "API Key", + "Digest Auth", + "OAuth 1.0", + "OAuth 2.0", + "Hawk Authentication", + "AWS Signature", + "JWT Bearer", + "ASAP (Atlassian)", + "NTLM Authentication", + "Akamai EdgeGrid", +) + +# Postman type key -> display name (excludes inherit / noauth) +AUTH_TYPE_LABELS: dict[str, str] = { + "bearer": "Bearer Token", + "basic": "Basic Auth", + "apikey": "API Key", + "digest": "Digest Auth", + "oauth1": "OAuth 1.0", + "oauth2": "OAuth 2.0", + "hawk": "Hawk Authentication", + "awsv4": "AWS Signature", + "jwt": "JWT Bearer", + "asap": "ASAP (Atlassian)", + "ntlm": "NTLM Authentication", + "edgegrid": "Akamai EdgeGrid", +} + +# Display name -> Postman type key +AUTH_TYPE_KEYS: dict[str, str] = {v: k for k, v in AUTH_TYPE_LABELS.items()} + +# Postman type key -> display name (all types including inherit / noauth) +AUTH_KEY_TO_DISPLAY: dict[str, str] = { + "inherit": "Inherit auth from parent", + "noauth": "No Auth", + **AUTH_TYPE_LABELS, +} + +# Order of field-based pages in the stacked widget (after inherit=0, noauth=1) +AUTH_FIELD_ORDER: tuple[str, ...] = ( + "bearer", + "basic", + "apikey", + "digest", + "oauth1", + "oauth2", + "hawk", + "awsv4", + "jwt", + "asap", + "ntlm", + "edgegrid", +) + +# Display name -> stacked-widget page index +AUTH_PAGE_INDEX: dict[str, int] = { + "Inherit auth from parent": 0, + "No Auth": 1, +} +for _i, _key in enumerate(AUTH_FIELD_ORDER, start=2): + AUTH_PAGE_INDEX[AUTH_TYPE_LABELS[_key]] = _i + +# Short description shown in the left column for each auth type +AUTH_TYPE_DESCRIPTIONS: dict[str, str] = { + "Inherit auth from parent": ( + "This request will use the authorization configured on its parent collection or folder." + ), + "No Auth": "This request does not use any authorization.", + "Bearer Token": ( + "The authorization header will be automatically generated when you send the request." + ), + "Basic Auth": ( + "The authorization header will be automatically generated when you send the request." + ), + "API Key": ( + "The key-value pair will be added as a header or query parameter when you send the request." + ), + "Digest Auth": ( + "The authorization header will be automatically generated when you send the request." + ), + "OAuth 1.0": ( + "The authorization header will be automatically generated when you send the request." + ), + "OAuth 2.0": ( + "The authorization header will be automatically generated when you send the request." + ), + "Hawk Authentication": ( + "The authorization header will be automatically generated when you send the request." + ), + "AWS Signature": ( + "The authorization header will be automatically generated when you send the request." + ), + "JWT Bearer": ( + "The authorization header will be automatically generated when you send the request." + ), + "ASAP (Atlassian)": ( + "The authorization header will be automatically generated when you send the request." + ), + "NTLM Authentication": ( + "NTLM credentials will be used for Windows authentication when you send the request." + ), + "Akamai EdgeGrid": ( + "The authorization header will be automatically generated when you send the request." + ), +} + + +# --------------------------------------------------------------------------- +# Field specifications +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True, slots=True) +class FieldSpec: + """Describes a single form field in an auth page. + + *kind* — ``"text"``, ``"password"``, ``"combo"``, or ``"textarea"``. + *combo_map* — Postman serialised value -> combo display text. When + empty the display text is used as-is for both load and save. + *save_as_bool* — When ``True`` the serialiser converts ``"true"`` + / ``"false"`` strings back to Python booleans on save. + *advanced* — When ``True`` the field is placed under a collapsible + "Advanced configuration" section, matching Postman's layout. + *suffix* — Optional suffix text displayed right of the input widget + (e.g. ``"bytes"``). + """ + + key: str + label: str + kind: str = "text" + placeholder: str = "" + options: tuple[str, ...] = () + combo_map: dict[str, str] = field(default_factory=dict) + default: str = "" + width: int | None = None + save_as_bool: bool = False + advanced: bool = False + suffix: str = "" + + +# --------------------------------------------------------------------------- +# Page builder +# --------------------------------------------------------------------------- + +_TEXTAREA_MAX_HEIGHT = 120 +_INPUT_MAX_WIDTH = 360 +_ADV_DESCRIPTION = ( + "Auto-generated default values are used for some of these fields unless a value is specified." +) + + +def _build_widget(spec: FieldSpec, on_change: Callable[[], None]) -> QWidget: + """Create the input widget for a single :class:`FieldSpec`.""" + if spec.kind == "combo": + w: QWidget = QComboBox() + assert isinstance(w, QComboBox) + w.addItems(list(spec.options)) + if spec.width: + w.setFixedWidth(spec.width) + else: + w.setMaximumWidth(_INPUT_MAX_WIDTH) + w.currentTextChanged.connect(on_change) + elif spec.kind == "checkbox": + w = QCheckBox(spec.label) + w.stateChanged.connect(lambda _: on_change()) + elif spec.kind == "password": + w = VariableLineEdit() + w.setPlaceholderText(spec.placeholder) + w.setEchoMode(QLineEdit.EchoMode.Password) + w.setMaximumWidth(_INPUT_MAX_WIDTH) + w.textChanged.connect(on_change) + elif spec.kind == "textarea": + w = QTextEdit() + assert isinstance(w, QTextEdit) + w.setPlaceholderText(spec.placeholder) + w.setMaximumHeight(_TEXTAREA_MAX_HEIGHT) + w.setMaximumWidth(_INPUT_MAX_WIDTH) + w.textChanged.connect(on_change) + else: + w = VariableLineEdit() + w.setPlaceholderText(spec.placeholder) + w.setMaximumWidth(_INPUT_MAX_WIDTH) + w.textChanged.connect(on_change) + return w + + +def _add_field_row( + form: QFormLayout, + spec: FieldSpec, + widget: QWidget, +) -> None: + """Add a label + widget row to *form*, handling suffix/checkbox.""" + if spec.kind == "checkbox": + form.addRow(widget) + return + lbl = QLabel(spec.label) + lbl.setObjectName("sectionLabel") + if spec.suffix: + row = QHBoxLayout() + row.setSpacing(6) + row.addWidget(widget) + suffix = QLabel(spec.suffix) + suffix.setObjectName("mutedLabel") + row.addWidget(suffix) + row.addStretch() + form.addRow(lbl, row) + else: + form.addRow(lbl, widget) + + +def build_fields_page( + specs: tuple[FieldSpec, ...], + on_change: Callable[[], None], +) -> tuple[QWidget, dict[str, QWidget]]: + """Build a form page from *specs* with an optional advanced section. + + Returns ``(page_widget, widgets_dict)`` where *widgets_dict* maps + each :attr:`FieldSpec.key` to the corresponding input widget. + + Primary fields appear at the top. If any spec has ``advanced=True``, + those fields are grouped under a collapsible "Advanced configuration" + toggle matching Postman's layout. + """ + primary = [s for s in specs if not s.advanced] + advanced = [s for s in specs if s.advanced] + + inner = QWidget() + root = QVBoxLayout(inner) + root.setContentsMargins(16, 8, 0, 0) + root.setSpacing(0) + + widgets: dict[str, QWidget] = {} + + # -- Primary fields ------------------------------------------------ + if primary: + primary_form = QFormLayout() + primary_form.setContentsMargins(0, 0, 0, 0) + primary_form.setHorizontalSpacing(12) + primary_form.setVerticalSpacing(10) + primary_form.setLabelAlignment(Qt.AlignmentFlag.AlignLeft | Qt.AlignmentFlag.AlignTop) + for spec in primary: + w = _build_widget(spec, on_change) + _add_field_row(primary_form, spec, w) + widgets[spec.key] = w + root.addLayout(primary_form) + + # -- Advanced section (collapsible) -------------------------------- + if advanced: + root.addSpacing(12) + + toggle = QToolButton() + toggle.setObjectName("advancedToggle") + toggle.setText("\u25b8 Advanced configuration") + toggle.setCheckable(True) + toggle.setChecked(False) + toggle.setToolButtonStyle(Qt.ToolButtonStyle.ToolButtonTextOnly) + toggle.setCursor(Qt.CursorShape.PointingHandCursor) + toggle.setStyleSheet("QToolButton { border: none; font-weight: bold; font-size: 12px; }") + root.addWidget(toggle) + + adv_container = QWidget() + adv_layout = QVBoxLayout(adv_container) + adv_layout.setContentsMargins(0, 4, 0, 0) + adv_layout.setSpacing(0) + + desc = QLabel(_ADV_DESCRIPTION) + desc.setObjectName("mutedLabel") + desc.setWordWrap(True) + adv_layout.addWidget(desc) + adv_layout.addSpacing(8) + + adv_form = QFormLayout() + adv_form.setContentsMargins(0, 0, 0, 0) + adv_form.setHorizontalSpacing(12) + adv_form.setVerticalSpacing(10) + adv_form.setLabelAlignment(Qt.AlignmentFlag.AlignLeft | Qt.AlignmentFlag.AlignTop) + for spec in advanced: + w = _build_widget(spec, on_change) + _add_field_row(adv_form, spec, w) + widgets[spec.key] = w + adv_layout.addLayout(adv_form) + + adv_container.setVisible(False) + root.addWidget(adv_container) + + def _on_toggle(checked: bool) -> None: + adv_container.setVisible(checked) + toggle.setText( + "\u25be Advanced configuration" if checked else "\u25b8 Advanced configuration" + ) + + toggle.toggled.connect(_on_toggle) + + root.addStretch() + + scroll = QScrollArea() + scroll.setWidgetResizable(True) + scroll.setFrameShape(QScrollArea.Shape.NoFrame) + scroll.setWidget(inner) + return scroll, widgets + + +def build_inherit_page() -> tuple[QWidget, QLabel]: + """Build the *Inherit auth from parent* right-side page. + + Returns ``(page_widget, preview_label)`` so the caller can update + the preview text as the active environment changes. + """ + page = QWidget() + layout = QVBoxLayout(page) + layout.setContentsMargins(0, 8, 0, 0) + preview = QLabel() + preview.setObjectName("sectionLabel") + preview.setWordWrap(True) + layout.addWidget(preview) + layout.addStretch() + return page, preview + + +def build_noauth_page() -> QWidget: + """Build the *No Auth* right-side page (empty placeholder).""" + page = QWidget() + return page diff --git a/src/ui/request/auth/auth_serializer.py b/src/ui/request/auth/auth_serializer.py new file mode 100644 index 0000000..c84ceaf --- /dev/null +++ b/src/ui/request/auth/auth_serializer.py @@ -0,0 +1,90 @@ +"""Auth field serialisation — load/save between Postman dicts and UI widgets. + +Uses the :data:`AUTH_FIELD_SPECS` registry from :mod:`auth_pages` +so that adding a new auth type requires **zero** changes here. +""" + +from __future__ import annotations + +from PySide6.QtWidgets import QCheckBox, QComboBox, QLineEdit, QTextEdit, QWidget + +from ui.request.auth.auth_field_specs import AUTH_FIELD_SPECS + + +def load_auth_fields( + auth_type: str, + widgets: dict[str, QWidget], + entries: list[dict], +) -> None: + """Populate *widgets* from a Postman key-value *entries* list. + + Each entry is ``{"key": "", "value": "", ...}``. + The *auth_type* selects the matching :class:`FieldSpec` list so + that combo-box mappings and defaults are applied correctly. + """ + specs = AUTH_FIELD_SPECS.get(auth_type, ()) + entry_map: dict[str, object] = { + e["key"]: e.get("value", "") for e in entries if isinstance(e, dict) + } + for spec in specs: + widget = widgets.get(spec.key) + if widget is None: + continue + raw = entry_map.get(spec.key, spec.default) + # Normalise booleans coming from Postman JSON + if isinstance(raw, bool): + value = "true" if raw else "false" + elif raw is None: + value = "" + else: + value = str(raw) + + if spec.kind == "combo" and isinstance(widget, QComboBox): + display = spec.combo_map.get(value, value) if spec.combo_map else value + widget.setCurrentText(display) + elif spec.kind == "checkbox" and isinstance(widget, QCheckBox): + widget.setChecked(value == "true") + elif spec.kind == "textarea" and isinstance(widget, QTextEdit): + widget.setPlainText(value) + elif isinstance(widget, QLineEdit): + widget.setText(value) + + +def get_auth_fields( + auth_type: str, + widgets: dict[str, QWidget], +) -> list[dict]: + """Serialise *widgets* into a Postman key-value entry list. + + Returns a list of ``{"key": ..., "value": ..., "type": "string"}`` + dicts ready for embedding in the auth configuration dict. + """ + specs = AUTH_FIELD_SPECS.get(auth_type, ()) + result: list[dict] = [] + for spec in specs: + widget = widgets.get(spec.key) + if widget is None: + continue + + if spec.kind == "combo" and isinstance(widget, QComboBox): + if spec.combo_map: + reverse = {v: k for k, v in spec.combo_map.items()} + raw_value: str | bool = reverse.get(widget.currentText(), widget.currentText()) + else: + raw_value = widget.currentText() + elif spec.kind == "checkbox" and isinstance(widget, QCheckBox): + raw_value = "true" if widget.isChecked() else "false" + elif spec.kind == "textarea" and isinstance(widget, QTextEdit): + raw_value = widget.toPlainText() + elif isinstance(widget, QLineEdit): + raw_value = widget.text() + else: + raw_value = "" + + # Convert "true"/"false" back to Python bools for Postman compat + if spec.save_as_bool and isinstance(raw_value, str) and raw_value in ("true", "false"): + result.append({"key": spec.key, "value": raw_value == "true", "type": "string"}) + else: + result.append({"key": spec.key, "value": raw_value, "type": "string"}) + return result + return result diff --git a/src/ui/request/auth/oauth2_page.py b/src/ui/request/auth/oauth2_page.py new file mode 100644 index 0000000..1e68b67 --- /dev/null +++ b/src/ui/request/auth/oauth2_page.py @@ -0,0 +1,462 @@ +"""Custom OAuth 2.0 page with grant-type switching and token management. + +Unlike FieldSpec-driven pages, this widget has specialised UI sections: + +1. **Current Token** — shows the active access token, header prefix, + and where to add it (header or query). +2. **Configure New Token** — grant-type selector with conditional + fields for Authorization Code, Implicit, Password Credentials, + and Client Credentials. +3. **Get New Access Token** button that triggers the OAuth flow. + +The page stores all values in Postman key-value format for seamless +round-tripping through :func:`load` / :func:`get_entries`. +""" + +from __future__ import annotations + +from collections.abc import Callable + +from PySide6.QtCore import Qt, Signal +from PySide6.QtWidgets import ( + QCheckBox, + QComboBox, + QFormLayout, + QHBoxLayout, + QLabel, + QLineEdit, + QPushButton, + QScrollArea, + QToolButton, + QVBoxLayout, + QWidget, +) + +from ui.widgets.variable_line_edit import VariableLineEdit + +_INPUT_MAX_WIDTH = 360 +_GRANT_TYPES = ( + "Authorization Code", + "Implicit", + "Password Credentials", + "Client Credentials", +) +_GRANT_TYPE_KEYS: dict[str, str] = { + "Authorization Code": "authorization_code", + "Implicit": "implicit", + "Password Credentials": "password", + "Client Credentials": "client_credentials", +} +_GRANT_KEY_TO_DISPLAY: dict[str, str] = {v: k for k, v in _GRANT_TYPE_KEYS.items()} + +_CLIENT_AUTH_OPTIONS = ( + "Send as Basic Auth header", + "Send client credentials in body", +) +_CLIENT_AUTH_KEYS: dict[str, str] = { + "Send as Basic Auth header": "header", + "Send client credentials in body": "body", +} +_CLIENT_AUTH_DISPLAY: dict[str, str] = {v: k for k, v in _CLIENT_AUTH_KEYS.items()} + + +class OAuth2Page(QScrollArea): + """Custom OAuth 2.0 configuration page. + + Signals: + field_changed: Emitted when any field value changes. + get_token_requested: Emitted when user clicks *Get New Access Token*. + """ + + field_changed = Signal() + get_token_requested = Signal() + + def __init__(self, on_change: Callable[[], None]) -> None: + """Build the OAuth 2.0 page with token and configuration sections.""" + super().__init__() + self._initializing = True + self._on_change = on_change + self.setWidgetResizable(True) + self.setFrameShape(QScrollArea.Shape.NoFrame) + + inner = QWidget() + root = QVBoxLayout(inner) + root.setContentsMargins(16, 8, 0, 0) + root.setSpacing(0) + + self._build_current_token_section(root) + root.addSpacing(16) + self._build_configure_section(root) + root.addSpacing(12) + self._build_get_token_button(root) + root.addStretch() + + self.setWidget(inner) + self._initializing = False + + # ------------------------------------------------------------------ + # Section builders + # ------------------------------------------------------------------ + + def _build_current_token_section(self, root: QVBoxLayout) -> None: + """Build the *Current Token* section.""" + header = QLabel("Current Token") + header.setObjectName("sectionLabel") + header.setStyleSheet("font-weight: bold; font-size: 12px;") + root.addWidget(header) + root.addSpacing(8) + + form = QFormLayout() + form.setContentsMargins(0, 0, 0, 0) + form.setHorizontalSpacing(12) + form.setVerticalSpacing(10) + form.setLabelAlignment(Qt.AlignmentFlag.AlignLeft | Qt.AlignmentFlag.AlignTop) + + self._token_name_display = VariableLineEdit() + self._token_name_display.setPlaceholderText("Token name") + self._token_name_display.setMaximumWidth(_INPUT_MAX_WIDTH) + self._token_name_display.setReadOnly(True) + _add_row(form, "Token", self._token_name_display) + + self._access_token = VariableLineEdit() + self._access_token.setPlaceholderText("Paste or obtain a token") + self._access_token.setMaximumWidth(_INPUT_MAX_WIDTH) + self._access_token.textChanged.connect(self._on_change) + _add_row(form, "Access Token", self._access_token) + + self._header_prefix = VariableLineEdit() + self._header_prefix.setPlaceholderText("Bearer") + self._header_prefix.setMaximumWidth(_INPUT_MAX_WIDTH) + self._header_prefix.setText("Bearer") + self._header_prefix.textChanged.connect(self._on_change) + _add_row(form, "Header Prefix", self._header_prefix) + + self._add_token_to = QComboBox() + self._add_token_to.addItems(("Header", "Query Params")) + self._add_token_to.setMaximumWidth(_INPUT_MAX_WIDTH) + self._add_token_to.currentTextChanged.connect(lambda _: self._on_change()) + _add_row(form, "Add token to", self._add_token_to) + + root.addLayout(form) + + def _build_configure_section(self, root: QVBoxLayout) -> None: + """Build the *Configure New Token* section with grant-type switching.""" + toggle = QToolButton() + toggle.setObjectName("advancedToggle") + toggle.setText("\u25b8 Configure New Token") + toggle.setCheckable(True) + toggle.setChecked(False) + toggle.setToolButtonStyle(Qt.ToolButtonStyle.ToolButtonTextOnly) + toggle.setCursor(Qt.CursorShape.PointingHandCursor) + toggle.setStyleSheet("QToolButton { border: none; font-weight: bold; font-size: 12px; }") + root.addWidget(toggle) + + self._config_container = QWidget() + config_layout = QVBoxLayout(self._config_container) + config_layout.setContentsMargins(0, 4, 0, 0) + config_layout.setSpacing(0) + + # Token Name + Grant Type + top_form = QFormLayout() + top_form.setContentsMargins(0, 0, 0, 0) + top_form.setHorizontalSpacing(12) + top_form.setVerticalSpacing(10) + top_form.setLabelAlignment(Qt.AlignmentFlag.AlignLeft | Qt.AlignmentFlag.AlignTop) + + self._token_name = VariableLineEdit() + self._token_name.setPlaceholderText("Token Name") + self._token_name.setMaximumWidth(_INPUT_MAX_WIDTH) + self._token_name.textChanged.connect(self._on_change) + _add_row(top_form, "Token Name", self._token_name) + + self._grant_type = QComboBox() + self._grant_type.addItems(_GRANT_TYPES) + self._grant_type.setMaximumWidth(_INPUT_MAX_WIDTH) + self._grant_type.currentTextChanged.connect(self._on_grant_type_changed) + _add_row(top_form, "Grant Type", self._grant_type) + + config_layout.addLayout(top_form) + config_layout.addSpacing(8) + + # Grant-type specific fields + self._grant_fields: dict[str, dict[str, QWidget]] = {} + self._grant_containers: dict[str, QWidget] = {} + for display_name in _GRANT_TYPES: + container, widgets = self._build_grant_fields(display_name) + self._grant_containers[display_name] = container + self._grant_fields[display_name] = widgets + config_layout.addWidget(container) + + self._config_container.setVisible(False) + root.addWidget(self._config_container) + + def _on_toggle(checked: bool) -> None: + self._config_container.setVisible(checked) + toggle.setText( + "\u25be Configure New Token" if checked else "\u25b8 Configure New Token" + ) + + toggle.toggled.connect(_on_toggle) + self._on_grant_type_changed(self._grant_type.currentText()) + + def _build_grant_fields(self, grant_display: str) -> tuple[QWidget, dict[str, QWidget]]: + """Build the conditional field group for a single grant type.""" + container = QWidget() + layout = QVBoxLayout(container) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(0) + + form = QFormLayout() + form.setContentsMargins(0, 0, 0, 0) + form.setHorizontalSpacing(12) + form.setVerticalSpacing(10) + form.setLabelAlignment(Qt.AlignmentFlag.AlignLeft | Qt.AlignmentFlag.AlignTop) + + widgets: dict[str, QWidget] = {} + + if grant_display == "Authorization Code": + widgets["callbackUrl"] = _text(form, "Callback URL", "https://localhost:5000/callback") + widgets["useBrowser"] = _checkbox(form, "Authorize using browser") + widgets["authUrl"] = _text(form, "Auth URL", "https://example.com/authorize") + widgets["accessTokenUrl"] = _text(form, "Access Token URL", "https://example.com/token") + widgets["clientId"] = _text(form, "Client ID", "Client ID") + widgets["clientSecret"] = _password(form, "Client Secret", "Client secret") + widgets["scope"] = _text(form, "Scope", "read write") + widgets["state"] = _text(form, "State", "random_state") + widgets["client_authentication"] = _auth_combo(form) + + elif grant_display == "Implicit": + widgets["callbackUrl"] = _text(form, "Callback URL", "https://localhost:5000/callback") + widgets["authUrl"] = _text(form, "Auth URL", "https://example.com/authorize") + widgets["clientId"] = _text(form, "Client ID", "Client ID") + widgets["scope"] = _text(form, "Scope", "read write") + widgets["state"] = _text(form, "State", "random_state") + + elif grant_display == "Password Credentials": + widgets["accessTokenUrl"] = _text(form, "Access Token URL", "https://example.com/token") + widgets["username"] = _text(form, "Username", "Username") + widgets["password"] = _password(form, "Password", "Password") + widgets["clientId"] = _text(form, "Client ID", "Client ID") + widgets["clientSecret"] = _password(form, "Client Secret", "Client secret") + widgets["scope"] = _text(form, "Scope", "read write") + widgets["client_authentication"] = _auth_combo(form) + + elif grant_display == "Client Credentials": + widgets["accessTokenUrl"] = _text(form, "Access Token URL", "https://example.com/token") + widgets["clientId"] = _text(form, "Client ID", "Client ID") + widgets["clientSecret"] = _password(form, "Client Secret", "Client secret") + widgets["scope"] = _text(form, "Scope", "read write") + widgets["client_authentication"] = _auth_combo(form) + + layout.addLayout(form) + + # Connect change signals + for w in widgets.values(): + if isinstance(w, QLineEdit): + w.textChanged.connect(self._on_change) + elif isinstance(w, QComboBox): + w.currentTextChanged.connect(lambda _: self._on_change()) + elif isinstance(w, QCheckBox): + w.stateChanged.connect(lambda _: self._on_change()) + + return container, widgets + + def _build_get_token_button(self, root: QVBoxLayout) -> None: + """Create the *Get New Access Token* button.""" + row = QHBoxLayout() + row.setContentsMargins(0, 0, 0, 0) + self._get_token_btn = QPushButton("Get New Access Token") + self._get_token_btn.setCursor(Qt.CursorShape.PointingHandCursor) + self._get_token_btn.setMaximumWidth(200) + self._get_token_btn.clicked.connect(self.get_token_requested.emit) + row.addWidget(self._get_token_btn) + row.addStretch() + root.addLayout(row) + + # ------------------------------------------------------------------ + # Grant type switching + # ------------------------------------------------------------------ + + def _on_grant_type_changed(self, display_name: str) -> None: + """Show fields for the selected grant type, hide others.""" + for name, container in self._grant_containers.items(): + container.setVisible(name == display_name) + if not self._initializing: + self._on_change() + + # ------------------------------------------------------------------ + # Load / Save (Postman key-value format) + # ------------------------------------------------------------------ + + def load(self, entries: list[dict]) -> None: + """Populate the page from Postman key-value *entries*.""" + entry_map: dict[str, object] = {} + for e in entries: + if isinstance(e, dict): + entry_map[e["key"]] = e.get("value", "") + + # Current token section + _set_text(self._access_token, str(entry_map.get("accessToken", ""))) + _set_text(self._header_prefix, str(entry_map.get("headerPrefix", "Bearer"))) + _set_text(self._token_name, str(entry_map.get("tokenName", ""))) + _set_text(self._token_name_display, str(entry_map.get("tokenName", ""))) + + add_to = str(entry_map.get("addTokenTo", "header")) + add_to_map = {"header": "Header", "queryParams": "Query Params"} + self._add_token_to.setCurrentText(add_to_map.get(add_to, "Header")) + + # Grant type + grant_raw = str(entry_map.get("grant_type", "authorization_code")) + grant_display = _GRANT_KEY_TO_DISPLAY.get(grant_raw, "Authorization Code") + self._grant_type.setCurrentText(grant_display) + + # Grant-specific fields (load into ALL containers — widgets may overlap) + for _display_name, widgets in self._grant_fields.items(): + for key, widget in widgets.items(): + raw = entry_map.get(key, "") + if isinstance(widget, QCheckBox): + widget.setChecked(raw is True or str(raw).lower() == "true") + elif isinstance(widget, QComboBox): + if key == "client_authentication": + display = _CLIENT_AUTH_DISPLAY.get(str(raw), str(raw)) + widget.setCurrentText(display) + else: + widget.setCurrentText(str(raw)) + elif isinstance(widget, QLineEdit): + _set_text(widget, str(raw) if raw else "") + + def get_entries(self) -> list[dict]: + """Serialise all field values to Postman key-value entry list.""" + entries: list[dict] = [] + + def _add(key: str, value: str | bool) -> None: + entries.append({"key": key, "value": value, "type": "string"}) + + # Current token + _add("accessToken", self._access_token.text()) + _add("headerPrefix", self._header_prefix.text() or "Bearer") + _add("tokenName", self._token_name.text()) + + add_to_rev = {"Header": "header", "Query Params": "queryParams"} + _add("addTokenTo", add_to_rev.get(self._add_token_to.currentText(), "header")) + + # Grant type + grant_display = self._grant_type.currentText() + grant_key = _GRANT_TYPE_KEYS.get(grant_display, "authorization_code") + _add("grant_type", grant_key) + + # Active grant fields only + active_widgets = self._grant_fields.get(grant_display, {}) + for key, widget in active_widgets.items(): + if isinstance(widget, QCheckBox): + _add(key, widget.isChecked()) + elif isinstance(widget, QComboBox): + if key == "client_authentication": + _add(key, _CLIENT_AUTH_KEYS.get(widget.currentText(), "header")) + else: + _add(key, widget.currentText()) + elif isinstance(widget, QLineEdit): + _add(key, widget.text()) + + return entries + + def get_config(self) -> dict: + """Return configuration needed for the token flow. + + Used by the token worker to know which grant type and endpoints + to use. + """ + grant_display = self._grant_type.currentText() + grant_key = _GRANT_TYPE_KEYS.get(grant_display, "authorization_code") + active = self._grant_fields.get(grant_display, {}) + + config: dict[str, str | bool] = {"grant_type": grant_key} + for key, widget in active.items(): + if isinstance(widget, QCheckBox): + config[key] = widget.isChecked() + elif isinstance(widget, QComboBox): + if key == "client_authentication": + config[key] = _CLIENT_AUTH_KEYS.get(widget.currentText(), "header") + else: + config[key] = widget.currentText() + elif isinstance(widget, QLineEdit): + config[key] = widget.text() + config["tokenName"] = self._token_name.text() + return config + + def set_token(self, token: str, name: str = "") -> None: + """Set the obtained token into the current-token section.""" + self._access_token.setText(token) + if name: + _set_text(self._token_name_display, name) + + def clear(self) -> None: + """Reset all fields to defaults.""" + self._access_token.clear() + self._header_prefix.setText("Bearer") + self._add_token_to.setCurrentIndex(0) + self._token_name.clear() + self._token_name_display.clear() + self._grant_type.setCurrentIndex(0) + for widgets in self._grant_fields.values(): + for w in widgets.values(): + if isinstance(w, QLineEdit): + w.clear() + elif isinstance(w, QComboBox): + w.setCurrentIndex(0) + elif isinstance(w, QCheckBox): + w.setChecked(False) + + +# ------------------------------------------------------------------ +# Internal helpers +# ------------------------------------------------------------------ + + +def _add_row(form: QFormLayout, label_text: str, widget: QWidget) -> None: + """Add a labelled row to *form*.""" + lbl = QLabel(label_text) + lbl.setObjectName("sectionLabel") + form.addRow(lbl, widget) + + +def _text(form: QFormLayout, label: str, placeholder: str) -> VariableLineEdit: + """Create a text input and add it to *form*.""" + w = VariableLineEdit() + w.setPlaceholderText(placeholder) + w.setMaximumWidth(_INPUT_MAX_WIDTH) + _add_row(form, label, w) + return w + + +def _password(form: QFormLayout, label: str, placeholder: str) -> VariableLineEdit: + """Create a password input and add it to *form*.""" + w = VariableLineEdit() + w.setPlaceholderText(placeholder) + w.setEchoMode(QLineEdit.EchoMode.Password) + w.setMaximumWidth(_INPUT_MAX_WIDTH) + _add_row(form, label, w) + return w + + +def _checkbox(form: QFormLayout, label: str) -> QCheckBox: + """Create a checkbox and add it to *form* (full-row).""" + cb = QCheckBox(label) + form.addRow(cb) + return cb + + +def _auth_combo(form: QFormLayout) -> QComboBox: + """Create a *Client Authentication* combo and add it to *form*.""" + w = QComboBox() + w.addItems(_CLIENT_AUTH_OPTIONS) + w.setMaximumWidth(_INPUT_MAX_WIDTH) + _add_row(form, "Client Authentication", w) + return w + + +def _set_text(widget: QLineEdit, text: str) -> None: + """Set text on a QLineEdit without triggering change signals.""" + widget.blockSignals(True) + widget.setText(text) + widget.blockSignals(False) diff --git a/src/ui/request/folder_editor.py b/src/ui/request/folder_editor.py index 36e94ff..9536d1d 100644 --- a/src/ui/request/folder_editor.py +++ b/src/ui/request/folder_editor.py @@ -11,24 +11,19 @@ from PySide6.QtCore import Qt, QTimer, Signal from PySide6.QtWidgets import ( - QComboBox, QFrame, QHBoxLayout, QLabel, - QLineEdit, - QStackedWidget, QTabWidget, QTextEdit, QVBoxLayout, QWidget, ) +from ui.request.auth import _AuthMixin from ui.widgets.code_editor import CodeEditorWidget from ui.widgets.key_value_table import KeyValueTableWidget -# Authorization type identifiers (same as RequestEditorWidget) -_AUTH_TYPES = ("No Auth", "Bearer Token", "Basic Auth", "API Key") - # Debounce delay (ms) for the collection_changed signal _DEBOUNCE_MS = 800 @@ -64,7 +59,7 @@ def _normalize_events(events: Any) -> dict[str, str]: return {} -class FolderEditorWidget(QWidget): +class FolderEditorWidget(_AuthMixin, QWidget): """Editable folder detail view with Overview, Auth, Scripts, and Variables. Call :meth:`load_collection` to populate the pane from a collection dict. @@ -157,95 +152,7 @@ def __init__(self, parent: QWidget | None = None) -> None: auth_layout = QVBoxLayout(self._auth_tab) auth_layout.setContentsMargins(0, 6, 0, 0) - auth_type_row = QHBoxLayout() - auth_type_row.setSpacing(8) - auth_type_label = QLabel("Type:") - auth_type_label.setObjectName("sectionLabel") - auth_type_row.addWidget(auth_type_label) - - self._auth_type_combo = QComboBox() - self._auth_type_combo.addItems(list(_AUTH_TYPES)) - self._auth_type_combo.setFixedWidth(140) - self._auth_type_combo.currentTextChanged.connect(self._on_auth_type_changed) - auth_type_row.addWidget(self._auth_type_combo) - auth_type_row.addStretch() - auth_layout.addLayout(auth_type_row) - - # Auth fields stack - self._auth_fields_stack = QStackedWidget() - - # No Auth page - no_auth_page = QLabel("This folder does not use any authorization.") - no_auth_page.setAlignment(Qt.AlignmentFlag.AlignCenter) - no_auth_page.setObjectName("emptyStateLabel") - self._auth_fields_stack.addWidget(no_auth_page) - - # Bearer Token page - bearer_page = QWidget() - bearer_layout = QVBoxLayout(bearer_page) - bearer_layout.setContentsMargins(0, 8, 0, 0) - token_label = QLabel("Token") - token_label.setObjectName("sectionLabel") - bearer_layout.addWidget(token_label) - self._bearer_token_input = QLineEdit() - self._bearer_token_input.setPlaceholderText("Enter bearer token") - self._bearer_token_input.textChanged.connect(self._on_field_changed) - bearer_layout.addWidget(self._bearer_token_input) - bearer_layout.addStretch() - self._auth_fields_stack.addWidget(bearer_page) - - # Basic Auth page - basic_page = QWidget() - basic_layout = QVBoxLayout(basic_page) - basic_layout.setContentsMargins(0, 8, 0, 0) - username_label = QLabel("Username") - username_label.setObjectName("sectionLabel") - basic_layout.addWidget(username_label) - self._basic_username_input = QLineEdit() - self._basic_username_input.setPlaceholderText("Username") - self._basic_username_input.textChanged.connect(self._on_field_changed) - basic_layout.addWidget(self._basic_username_input) - password_label = QLabel("Password") - password_label.setObjectName("sectionLabel") - basic_layout.addWidget(password_label) - self._basic_password_input = QLineEdit() - self._basic_password_input.setPlaceholderText("Password") - self._basic_password_input.setEchoMode(QLineEdit.EchoMode.Password) - self._basic_password_input.textChanged.connect(self._on_field_changed) - basic_layout.addWidget(self._basic_password_input) - basic_layout.addStretch() - self._auth_fields_stack.addWidget(basic_page) - - # API Key page - apikey_page = QWidget() - apikey_layout = QVBoxLayout(apikey_page) - apikey_layout.setContentsMargins(0, 8, 0, 0) - key_label = QLabel("Key") - key_label.setObjectName("sectionLabel") - apikey_layout.addWidget(key_label) - self._apikey_key_input = QLineEdit() - self._apikey_key_input.setPlaceholderText("Header or query parameter name") - self._apikey_key_input.textChanged.connect(self._on_field_changed) - apikey_layout.addWidget(self._apikey_key_input) - value_label = QLabel("Value") - value_label.setObjectName("sectionLabel") - apikey_layout.addWidget(value_label) - self._apikey_value_input = QLineEdit() - self._apikey_value_input.setPlaceholderText("API key value") - self._apikey_value_input.textChanged.connect(self._on_field_changed) - apikey_layout.addWidget(self._apikey_value_input) - add_to_label = QLabel("Add to") - add_to_label.setObjectName("sectionLabel") - apikey_layout.addWidget(add_to_label) - self._apikey_add_to_combo = QComboBox() - self._apikey_add_to_combo.addItems(["Header", "Query Params"]) - self._apikey_add_to_combo.setFixedWidth(140) - self._apikey_add_to_combo.currentTextChanged.connect(self._on_field_changed) - apikey_layout.addWidget(self._apikey_add_to_combo) - apikey_layout.addStretch() - self._auth_fields_stack.addWidget(apikey_page) - - auth_layout.addWidget(self._auth_fields_stack, 1) + self._build_auth_tab(auth_layout) self._tabs.addTab(self._auth_tab, "Authorization") # ---- Scripts tab ---- @@ -351,7 +258,7 @@ def load_collection( self._description_edit.setPlainText(data.get("description") or "") # Auth - self._load_auth(data.get("auth") or {}) + self._load_auth(data.get("auth")) # Scripts (events -- accept both dict and Postman list format) events = _normalize_events(data.get("events")) @@ -377,49 +284,6 @@ def load_collection( finally: self._loading = False - def _load_auth(self, auth: dict) -> None: - """Populate the auth UI from a Postman-style auth dict.""" - auth_type = auth.get("type", "noauth") - auth_type_map = { - "noauth": "No Auth", - "bearer": "Bearer Token", - "basic": "Basic Auth", - "apikey": "API Key", - } - self._auth_type_combo.setCurrentText(auth_type_map.get(auth_type, "No Auth")) - if auth_type == "bearer": - bearer_list = auth.get("bearer", []) - token = "" - for entry in bearer_list: - if entry.get("key") == "token": - token = entry.get("value", "") - self._bearer_token_input.setText(token) - elif auth_type == "basic": - basic_list = auth.get("basic", []) - username = password = "" - for entry in basic_list: - if entry.get("key") == "username": - username = entry.get("value", "") - elif entry.get("key") == "password": - password = entry.get("value", "") - self._basic_username_input.setText(username) - self._basic_password_input.setText(password) - elif auth_type == "apikey": - apikey_list = auth.get("apikey", []) - key = value = "" - add_to = "header" - for entry in apikey_list: - if entry.get("key") == "key": - key = entry.get("value", "") - elif entry.get("key") == "value": - value = entry.get("value", "") - elif entry.get("key") == "in": - add_to = entry.get("value", "header") - self._apikey_key_input.setText(key) - self._apikey_value_input.setText(value) - add_to_map = {"header": "Header", "query": "Query Params"} - self._apikey_add_to_combo.setCurrentText(add_to_map.get(add_to, "Header")) - def get_collection_data(self) -> dict: """Return the current editor state as a dict suitable for saving.""" return { @@ -441,13 +305,7 @@ def clear(self) -> None: self._updated_label.setText("") self._recent_requests_label.setText("") self._description_edit.clear() - self._auth_type_combo.setCurrentText("No Auth") - self._bearer_token_input.clear() - self._basic_username_input.clear() - self._basic_password_input.clear() - self._apikey_key_input.clear() - self._apikey_value_input.clear() - self._apikey_add_to_combo.setCurrentIndex(0) + self._clear_auth() self._pre_request_edit.clear() self._test_script_edit.clear() self._variables_table.set_data([]) @@ -474,76 +332,6 @@ def _load_recent_requests(self, requests: list[dict[str, Any]]) -> None: lines.append(f"{method} {name}{ts_str}") self._recent_requests_label.setText("\n".join(lines)) - # -- Auth helpers -------------------------------------------------- - - def _on_auth_type_changed(self, auth_type: str) -> None: - """Switch the auth fields stack page based on the selected type.""" - page_map = { - "No Auth": 0, - "Bearer Token": 1, - "Basic Auth": 2, - "API Key": 3, - } - self._auth_fields_stack.setCurrentIndex(page_map.get(auth_type, 0)) - if not self._loading: - self._debounce_timer.start() - - def _get_auth_data(self) -> dict | None: - """Build the auth configuration dict from the current UI state.""" - auth_type = self._auth_type_combo.currentText() - if auth_type == "No Auth": - return {"type": "noauth"} - if auth_type == "Bearer Token": - return { - "type": "bearer", - "bearer": [ - { - "key": "token", - "value": self._bearer_token_input.text(), - "type": "string", - }, - ], - } - if auth_type == "Basic Auth": - return { - "type": "basic", - "basic": [ - { - "key": "username", - "value": self._basic_username_input.text(), - "type": "string", - }, - { - "key": "password", - "value": self._basic_password_input.text(), - "type": "string", - }, - ], - } - if auth_type == "API Key": - add_to = "header" if self._apikey_add_to_combo.currentText() == "Header" else "query" - return { - "type": "apikey", - "apikey": [ - { - "key": "key", - "value": self._apikey_key_input.text(), - "type": "string", - }, - { - "key": "value", - "value": self._apikey_value_input.text(), - "type": "string", - }, - { - "key": "in", - "value": add_to, - "type": "string", - }, - ], - } - return None - # -- Events / scripts helpers -------------------------------------- def _get_events_data(self) -> dict | None: diff --git a/src/ui/request/http_worker.py b/src/ui/request/http_worker.py index fbd0cba..1449f76 100644 --- a/src/ui/request/http_worker.py +++ b/src/ui/request/http_worker.py @@ -149,7 +149,14 @@ def run(self) -> None: # 3. Apply auth configuration if self._auth_data: - url, headers = self._apply_auth(self._auth_data, url, headers, variables) + url, headers = self._apply_auth( + self._auth_data, + url, + headers, + variables, + method=self._method, + body=body, + ) result: HttpResponseDict = HttpService.send_request( method=self._method, @@ -175,61 +182,52 @@ def _apply_auth( url: str, headers: str | None, variables: dict[str, str], + *, + method: str = "GET", + body: str | None = None, ) -> tuple[str, str | None]: """Inject auth credentials into the URL or headers. + Substitutes environment variables in auth entry values, then + delegates to :func:`services.http.auth_handler.apply_auth`. Returns the (possibly modified) ``url`` and ``headers``. """ if not auth_data: return url, headers from services.environment_service import EnvironmentService + from services.http.auth_handler import apply_auth + from services.http.header_utils import parse_header_dict - auth_type = auth_data.get("type", "noauth") sub = EnvironmentService.substitute + auth_type = auth_data.get("type", "noauth") - if auth_type == "bearer": - token = "" - for entry in auth_data.get("bearer", []): - if entry.get("key") == "token": - token = sub(entry.get("value", ""), variables) - if token: - auth_line = f"Authorization: Bearer {token}" - headers = f"{headers}\n{auth_line}" if headers else auth_line - - elif auth_type == "basic": - import base64 - - username = password = "" - for entry in auth_data.get("basic", []): - if entry.get("key") == "username": - username = sub(entry.get("value", ""), variables) - elif entry.get("key") == "password": - password = sub(entry.get("value", ""), variables) - if username or password: - encoded = base64.b64encode(f"{username}:{password}".encode()).decode() - auth_line = f"Authorization: Basic {encoded}" - headers = f"{headers}\n{auth_line}" if headers else auth_line - - elif auth_type == "apikey": - key = value = "" - add_to = "header" - for entry in auth_data.get("apikey", []): - if entry.get("key") == "key": - key = sub(entry.get("value", ""), variables) - elif entry.get("key") == "value": - value = sub(entry.get("value", ""), variables) - elif entry.get("key") == "in": - add_to = entry.get("value", "header") - if key and value: - if add_to == "header": - auth_line = f"{key}: {value}" - headers = f"{headers}\n{auth_line}" if headers else auth_line - else: - sep = "&" if "?" in url else "?" - url = f"{url}{sep}{key}={value}" - - return url, headers + # Substitute variables in entry values (shallow copy to avoid mutation) + entries = auth_data.get(auth_type, []) + if entries and variables: + substituted = dict(auth_data) + substituted[auth_type] = [ + {**e, "value": sub(str(e.get("value", "")), variables)} + if isinstance(e, dict) + else e + for e in entries + ] + else: + substituted = auth_data + + # Convert header string to dict, apply auth, convert back + hdr_dict = parse_header_dict(headers) + url, hdr_dict = apply_auth( + substituted, + url, + hdr_dict, + method=method, + body=body, + ) + new_headers: str | None = ( + "\n".join(f"{k}: {v}" for k, v in hdr_dict.items()) if hdr_dict else None + ) + return url, new_headers class SchemaFetchWorker(QObject): @@ -284,3 +282,47 @@ def run(self) -> None: except Exception as exc: logger.exception("Schema fetch worker failed") self.error.emit(str(exc)) + + +class OAuth2TokenWorker(QObject): + """Execute an OAuth 2.0 token flow on a background thread. + + Set configuration via :meth:`set_config` **before** calling + ``moveToThread()``. Connect ``finished`` and ``error`` signals, + then start the owning ``QThread``. + + Signals: + finished(dict): Emitted with an :class:`OAuth2TokenResult` on success. + error(str): Emitted with an error message on failure. + """ + + finished = Signal(dict) + error = Signal(str) + + def __init__(self) -> None: + """Initialise with empty configuration.""" + super().__init__() + self._config: dict = {} + + def set_config(self, config: dict) -> None: + """Configure the OAuth 2.0 flow parameters. + + Must be called **before** the worker is moved to its thread. + """ + self._config = config + + @Slot() + def run(self) -> None: + """Perform the token exchange and emit the result signal.""" + try: + from services.http.oauth2_service import OAuth2Service + + result = OAuth2Service.get_token(self._config) + if result.get("error"): + self.error.emit(str(result["error"])) + else: + self.finished.emit(dict(result)) + except Exception as exc: + logger.exception("OAuth 2.0 token worker failed") + self.error.emit(str(exc)) + self.error.emit(str(exc)) diff --git a/src/ui/request/navigation/request_tab_bar.py b/src/ui/request/navigation/request_tab_bar.py index f94379a..e0a9cf2 100644 --- a/src/ui/request/navigation/request_tab_bar.py +++ b/src/ui/request/navigation/request_tab_bar.py @@ -322,11 +322,20 @@ def tab_label(self, index: int) -> _TabLabel | None: def mouseDoubleClickEvent(self, event: QMouseEvent) -> None: """Emit double-click signal for tab promotion.""" - index = self.tabAt(event.pos()) + index = self.tabAt(event.position().toPoint()) if index >= 0: self.tab_double_clicked.emit(index) super().mouseDoubleClickEvent(event) + def mousePressEvent(self, event: QMouseEvent) -> None: + """Close tab on middle-click.""" + if event.button() == Qt.MouseButton.MiddleButton: + index = self.tabAt(event.position().toPoint()) + if index >= 0: + self.tab_close_requested.emit(index) + return + super().mousePressEvent(event) + def contextMenuEvent(self, event: QMouseEvent) -> None: # type: ignore[override] """Show right-click context menu with Close / Close Others / Close All.""" index = self.tabAt(event.pos()) diff --git a/src/ui/request/navigation/tab_manager.py b/src/ui/request/navigation/tab_manager.py index 7c7846c..2b1376e 100644 --- a/src/ui/request/navigation/tab_manager.py +++ b/src/ui/request/navigation/tab_manager.py @@ -43,6 +43,9 @@ class TabContext: is_dirty: Whether the editor has unsaved changes. is_sending: Whether an HTTP request is currently in flight. is_preview: Whether this tab is in preview mode (temporary). + draft_name: Display name for unsaved draft tabs. ``None`` for + persisted requests. Updated when the user renames via the + breadcrumb bar. local_overrides: Transient per-tab variable overrides. When the user edits a variable value in the popup without clicking **Update**, the override is stored here and @@ -76,6 +79,7 @@ def __init__( self.is_dirty: bool = False self.is_sending: bool = False self.is_preview: bool = is_preview + self.draft_name: str | None = None self.local_overrides: dict[str, LocalOverride] = {} # -- Send lifecycle ------------------------------------------------ diff --git a/src/ui/request/request_editor/auth.py b/src/ui/request/request_editor/auth.py index 9d800c0..df30860 100644 --- a/src/ui/request/request_editor/auth.py +++ b/src/ui/request/request_editor/auth.py @@ -1,241 +1,12 @@ -"""Authorization tab mixin for the request editor. +"""Authorization tab mixin re-export. -Provides ``_AuthMixin`` with the auth tab UI construction -(type selector and field pages for Bearer, Basic, API Key) and -serialisation / deserialisation helpers. Mixed into -``RequestEditorWidget``. +This module re-exports :class:`_AuthMixin` from the shared +``ui.request.auth`` sub-package so that existing imports +continue to work. """ from __future__ import annotations -from typing import TYPE_CHECKING +from ui.request.auth.auth_mixin import _AuthMixin -from PySide6.QtCore import Qt -from PySide6.QtWidgets import ( - QComboBox, - QHBoxLayout, - QLabel, - QLineEdit, - QStackedWidget, - QVBoxLayout, - QWidget, -) - -from ui.widgets.variable_line_edit import VariableLineEdit - -if TYPE_CHECKING: - from PySide6.QtCore import QTimer - -# Authorization type identifiers (must match editor_widget._AUTH_TYPES order) -_AUTH_TYPES = ("No Auth", "Bearer Token", "Basic Auth", "API Key") - - -class _AuthMixin: - """Mixin that adds auth tab building and auth data helpers. - - Expects the host class to provide ``_on_field_changed``, - ``_loading``, ``_set_dirty``, and ``_debounce_timer`` attributes. - """ - - # -- Host-class interface (declared for mypy) ----------------------- - _loading: bool - _debounce_timer: QTimer - - def _on_field_changed(self) -> None: ... - def _set_dirty(self, dirty: bool) -> None: ... - def _sync_tab_indicators(self) -> None: ... - - # -- UI construction (called from __init__) ------------------------- - - def _build_auth_tab(self, auth_layout: QVBoxLayout) -> None: - """Construct the auth tab contents: type selector + fields stack.""" - auth_type_row = QHBoxLayout() - auth_type_row.setSpacing(8) - auth_type_label = QLabel("Type:") - auth_type_label.setObjectName("sectionLabel") - auth_type_row.addWidget(auth_type_label) - - self._auth_type_combo = QComboBox() - self._auth_type_combo.addItems(list(_AUTH_TYPES)) - self._auth_type_combo.setFixedWidth(140) - self._auth_type_combo.currentTextChanged.connect(self._on_auth_type_changed) - auth_type_row.addWidget(self._auth_type_combo) - auth_type_row.addStretch() - auth_layout.addLayout(auth_type_row) - - self._auth_fields_stack = QStackedWidget() - - # No Auth page - no_auth_page = QLabel("This request does not use any authorization.") - no_auth_page.setAlignment(Qt.AlignmentFlag.AlignCenter) - no_auth_page.setObjectName("emptyStateLabel") - self._auth_fields_stack.addWidget(no_auth_page) - - # Bearer Token page - bearer_page = QWidget() - bearer_layout = QVBoxLayout(bearer_page) - bearer_layout.setContentsMargins(0, 8, 0, 0) - token_label = QLabel("Token") - token_label.setObjectName("sectionLabel") - bearer_layout.addWidget(token_label) - self._bearer_token_input = VariableLineEdit() - self._bearer_token_input.setPlaceholderText("Enter bearer token") - self._bearer_token_input.textChanged.connect(self._on_field_changed) - bearer_layout.addWidget(self._bearer_token_input) - bearer_layout.addStretch() - self._auth_fields_stack.addWidget(bearer_page) - - # Basic Auth page - basic_page = QWidget() - basic_layout = QVBoxLayout(basic_page) - basic_layout.setContentsMargins(0, 8, 0, 0) - username_label = QLabel("Username") - username_label.setObjectName("sectionLabel") - basic_layout.addWidget(username_label) - self._basic_username_input = VariableLineEdit() - self._basic_username_input.setPlaceholderText("Username") - self._basic_username_input.textChanged.connect(self._on_field_changed) - basic_layout.addWidget(self._basic_username_input) - password_label = QLabel("Password") - password_label.setObjectName("sectionLabel") - basic_layout.addWidget(password_label) - self._basic_password_input = VariableLineEdit() - self._basic_password_input.setPlaceholderText("Password") - self._basic_password_input.setEchoMode(QLineEdit.EchoMode.Password) - self._basic_password_input.textChanged.connect(self._on_field_changed) - basic_layout.addWidget(self._basic_password_input) - basic_layout.addStretch() - self._auth_fields_stack.addWidget(basic_page) - - # API Key page - apikey_page = QWidget() - apikey_layout = QVBoxLayout(apikey_page) - apikey_layout.setContentsMargins(0, 8, 0, 0) - key_label = QLabel("Key") - key_label.setObjectName("sectionLabel") - apikey_layout.addWidget(key_label) - self._apikey_key_input = VariableLineEdit() - self._apikey_key_input.setPlaceholderText("Header or query parameter name") - self._apikey_key_input.textChanged.connect(self._on_field_changed) - apikey_layout.addWidget(self._apikey_key_input) - value_label = QLabel("Value") - value_label.setObjectName("sectionLabel") - apikey_layout.addWidget(value_label) - self._apikey_value_input = VariableLineEdit() - self._apikey_value_input.setPlaceholderText("API key value") - self._apikey_value_input.textChanged.connect(self._on_field_changed) - apikey_layout.addWidget(self._apikey_value_input) - - add_to_label = QLabel("Add to") - add_to_label.setObjectName("sectionLabel") - apikey_layout.addWidget(add_to_label) - self._apikey_add_to_combo = QComboBox() - self._apikey_add_to_combo.addItems(["Header", "Query Params"]) - self._apikey_add_to_combo.setFixedWidth(140) - self._apikey_add_to_combo.currentTextChanged.connect(self._on_field_changed) - apikey_layout.addWidget(self._apikey_add_to_combo) - apikey_layout.addStretch() - self._auth_fields_stack.addWidget(apikey_page) - - auth_layout.addWidget(self._auth_fields_stack, 1) - - # -- Auth type switching ------------------------------------------- - - def _on_auth_type_changed(self, auth_type: str) -> None: - """Switch the auth fields stack page based on the selected type.""" - page_map = { - "No Auth": 0, - "Bearer Token": 1, - "Basic Auth": 2, - "API Key": 3, - } - self._auth_fields_stack.setCurrentIndex(page_map.get(auth_type, 0)) - if not self._loading: - self._set_dirty(True) - self._debounce_timer.start() - self._sync_tab_indicators() - - # -- Load / save helpers ------------------------------------------- - - def _load_auth(self, auth: dict) -> None: - """Populate auth fields from a Postman-format auth dict.""" - auth_type = auth.get("type", "noauth") - auth_type_map = { - "noauth": "No Auth", - "bearer": "Bearer Token", - "basic": "Basic Auth", - "apikey": "API Key", - } - self._auth_type_combo.setCurrentText(auth_type_map.get(auth_type, "No Auth")) - if auth_type == "bearer": - bearer_list = auth.get("bearer", []) - token = "" - for entry in bearer_list: - if entry.get("key") == "token": - token = entry.get("value", "") - self._bearer_token_input.setText(token) - elif auth_type == "basic": - basic_list = auth.get("basic", []) - username = password = "" - for entry in basic_list: - if entry.get("key") == "username": - username = entry.get("value", "") - elif entry.get("key") == "password": - password = entry.get("value", "") - self._basic_username_input.setText(username) - self._basic_password_input.setText(password) - elif auth_type == "apikey": - apikey_list = auth.get("apikey", []) - key = value = "" - add_to = "header" - for entry in apikey_list: - if entry.get("key") == "key": - key = entry.get("value", "") - elif entry.get("key") == "value": - value = entry.get("value", "") - elif entry.get("key") == "in": - add_to = entry.get("value", "header") - self._apikey_key_input.setText(key) - self._apikey_value_input.setText(value) - add_to_map = {"header": "Header", "query": "Query Params"} - self._apikey_add_to_combo.setCurrentText(add_to_map.get(add_to, "Header")) - - def _get_auth_data(self) -> dict | None: - """Build the auth configuration dict from the current UI state.""" - auth_type = self._auth_type_combo.currentText() - if auth_type == "No Auth": - return {"type": "noauth"} - if auth_type == "Bearer Token": - return { - "type": "bearer", - "bearer": [ - {"key": "token", "value": self._bearer_token_input.text(), "type": "string"}, - ], - } - if auth_type == "Basic Auth": - return { - "type": "basic", - "basic": [ - { - "key": "username", - "value": self._basic_username_input.text(), - "type": "string", - }, - { - "key": "password", - "value": self._basic_password_input.text(), - "type": "string", - }, - ], - } - if auth_type == "API Key": - add_to = "header" if self._apikey_add_to_combo.currentText() == "Header" else "query" - return { - "type": "apikey", - "apikey": [ - {"key": "key", "value": self._apikey_key_input.text(), "type": "string"}, - {"key": "value", "value": self._apikey_value_input.text(), "type": "string"}, - {"key": "in", "value": add_to, "type": "string"}, - ], - } - return None +__all__ = ["_AuthMixin"] diff --git a/src/ui/request/request_editor/editor_widget.py b/src/ui/request/request_editor/editor_widget.py index 82159ff..e6f984c 100644 --- a/src/ui/request/request_editor/editor_widget.py +++ b/src/ui/request/request_editor/editor_widget.py @@ -290,7 +290,7 @@ def load_request(self, data: RequestLoadDict, *, request_id: int | None = None) else: self._scripts_edit.setPlainText("") - self._load_auth(data.get("auth") or {}) + self._load_auth(data.get("auth")) self._description_edit.setPlainText(data.get("description") or "") self._set_dirty(False) finally: @@ -408,7 +408,7 @@ def clear_request(self) -> None: self._scripts_edit.clear() self._body_mode_buttons["none"].setChecked(True) self._raw_format_combo.setCurrentText("Text") - self._auth_type_combo.setCurrentText("No Auth") + self._auth_type_combo.setCurrentText("Inherit auth from parent") self._bearer_token_input.clear() self._basic_username_input.clear() self._basic_password_input.clear() @@ -439,11 +439,13 @@ def _on_field_changed(self) -> None: def _sync_tab_indicators(self) -> None: """Append a dot indicator to section tabs that contain data.""" + if not hasattr(self, "_scripts_edit"): + return has_content = [ bool(self._params_table.get_data()), bool(self._headers_table.get_data()), not self._body_mode_buttons.get("none", QRadioButton()).isChecked(), - self._auth_type_combo.currentText() != "No Auth", + self._auth_type_combo.currentText() not in ("No Auth", "Inherit auth from parent"), bool(self._description_edit.toPlainText().strip()), bool(self._scripts_edit.toPlainText().strip()), ] diff --git a/src/ui/sidebar/__init__.py b/src/ui/sidebar/__init__.py new file mode 100644 index 0000000..4a19b1a --- /dev/null +++ b/src/ui/sidebar/__init__.py @@ -0,0 +1,12 @@ +"""Right sidebar sub-package. + +Re-exports :class:`RightSidebar` for use from the main window: + + from ui.sidebar import RightSidebar +""" + +from __future__ import annotations + +from ui.sidebar.sidebar_widget import RightSidebar + +__all__ = ["RightSidebar"] diff --git a/src/ui/sidebar/sidebar_widget.py b/src/ui/sidebar/sidebar_widget.py new file mode 100644 index 0000000..51f45a8 --- /dev/null +++ b/src/ui/sidebar/sidebar_widget.py @@ -0,0 +1,395 @@ +"""Postman-style right sidebar with icon rail and flyout panel. + +The sidebar consists of two widgets placed as **separate children** in +the parent ``QSplitter``: + +- :class:`_FlyoutPanel` — collapsible content area (variables / + code-snippet). The QSplitter enforces its ``minimumSizeHint`` so + content is never crushed: dragging past the minimum snaps it to 0. +- :class:`RightSidebar` — the always-visible icon rail. + +``RightSidebar`` owns the flyout and exposes the same public API as +before. Call :pymethod:`install_in_splitter` after construction to +place both widgets into the target splitter. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from PySide6.QtCore import QSize, Qt +from PySide6.QtWidgets import ( + QLabel, + QPushButton, + QSizePolicy, + QSplitter, + QToolButton, + QVBoxLayout, + QWidget, +) + +from ui.sidebar.snippet_panel import SnippetPanel +from ui.sidebar.variables_panel import VariablesPanel +from ui.styling.icons import phi + +if TYPE_CHECKING: + from services.environment_service import LocalOverride, VariableDetail + + +# ------------------------------------------------------------------ +# Flyout panel — separate splitter child +# ------------------------------------------------------------------ +class _FlyoutPanel(QWidget): + """Collapsible content panel placed as its own splitter child.""" + + def __init__(self, parent: QWidget | None = None) -> None: + """Build title bar, variables panel and snippet panel.""" + super().__init__(parent) + self.setObjectName("sidebarPanelArea") + self.setAttribute(Qt.WidgetAttribute.WA_StyledBackground, True) + + em = self.fontMetrics().height() + self._min_width: int = round(12.0 * em) + + layout = QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(0) + + # Title bar + from PySide6.QtWidgets import QHBoxLayout + + title_bar = QHBoxLayout() + title_bar.setContentsMargins(12, 8, 8, 4) + title_bar.setSpacing(4) + + self.title_label = QLabel() + self.title_label.setObjectName("sidebarTitleLabel") + title_bar.addWidget(self.title_label) + title_bar.addStretch() + + self.close_btn = QPushButton() + self.close_btn.setObjectName("iconButton") + self.close_btn.setFixedSize(28, 28) + self.close_btn.setIcon(phi("x", size=16)) + self.close_btn.setCursor(Qt.CursorShape.PointingHandCursor) + self.close_btn.setToolTip("Close panel") + title_bar.addWidget(self.close_btn) + + layout.addLayout(title_bar) + + # Content panels + self.variables_panel = VariablesPanel() + self.snippet_panel = SnippetPanel() + self.snippet_panel.setSizePolicy( + QSizePolicy.Policy.Preferred, + QSizePolicy.Policy.Expanding, + ) + layout.addWidget(self.variables_panel, 1) + layout.addWidget(self.snippet_panel, 1) + self.variables_panel.hide() + self.snippet_panel.hide() + + def minimumSizeHint(self) -> QSize: + """Enforce a readable minimum width for the flyout.""" + return QSize(self._min_width, 0) + + +# ------------------------------------------------------------------ +# Icon rail + controller +# ------------------------------------------------------------------ +class RightSidebar(QWidget): + """Always-visible icon rail that controls a flyout panel. + + After construction, call :pymethod:`install_in_splitter` to place + both the flyout and the rail into the parent splitter. + """ + + def __init__(self, parent: QWidget | None = None) -> None: + """Initialise the icon rail and create the flyout panel.""" + super().__init__(parent) + self.setObjectName("sidebarRail") + + # Derive sizes from the application font. + em = self.fontMetrics().height() + self._rail_width: int = round(2.0 * em) + self._icon_size: int = em + self._btn_size: int = self._rail_width - round(0.35 * em) + self._panel_hint_width: int = round(15.0 * em) + + self.setFixedWidth(self._rail_width) + + # --- Flyout (separate widget, placed in splitter later) ------- + self._flyout = _FlyoutPanel() + self._close_btn = self._flyout.close_btn + self._title_label = self._flyout.title_label + self._variables_panel = self._flyout.variables_panel + self._snippet_panel = self._flyout.snippet_panel + self._close_btn.clicked.connect(self._close_panel) + + # --- Rail layout ---------------------------------------------- + rail_layout = QVBoxLayout(self) + rail_layout.setContentsMargins(0, 6, 0, 6) + rail_layout.setSpacing(2) + + self._var_btn = self._make_rail_button( + "brackets-curly", + "Variables", + ) + self._snippet_btn = self._make_rail_button("code", "Code snippet") + self._snippet_btn.hide() + rail_layout.addWidget(self._var_btn) + rail_layout.addWidget(self._snippet_btn) + rail_layout.addStretch() + + # State + self._active_panel: str | None = None + self._last_panel: str | None = None + self._available_panels: set[str] = set() + self._default_panel: str | None = None + self._splitter: QSplitter | None = None + self._flyout_idx: int = -1 + + # Wire rail buttons + self._var_btn.clicked.connect(lambda: self._toggle_panel("variables")) + self._snippet_btn.clicked.connect( + lambda: self._toggle_panel("snippet"), + ) + + # Keep a reference for the ``_rail`` attribute used by tests. + @property + def _rail(self) -> QWidget: + """Return self — the rail *is* this widget.""" + return self + + # ------------------------------------------------------------------ + # Splitter integration + # ------------------------------------------------------------------ + def install_in_splitter(self, splitter: QSplitter) -> None: + """Add the flyout and rail as children of *splitter*. + + Must be called **after** the content area has been added to + the splitter so the flyout sits between content and rail. + """ + self._splitter = splitter + splitter.addWidget(self._flyout) + self._flyout_idx = splitter.indexOf(self._flyout) + splitter.addWidget(self) + rail_idx = splitter.indexOf(self) + + # Flyout: collapsible (snap-to-close), no stretch. + splitter.setCollapsible(self._flyout_idx, True) + splitter.setStretchFactor(self._flyout_idx, 0) + + # Rail: fixed, non-collapsible, no stretch. + splitter.setCollapsible(rail_idx, False) + splitter.setStretchFactor(rail_idx, 0) + + # Hide the handle between flyout and rail — it's not useful + # since the rail is fixed-width. + rail_handle = splitter.handle(rail_idx) + if rail_handle: + rail_handle.setFixedWidth(0) + rail_handle.setEnabled(False) + + # Start collapsed — flyout at 0 width. + sizes = splitter.sizes() + sizes[self._flyout_idx] = 0 + splitter.setSizes(sizes) + + # React when user drags the splitter handle. + splitter.splitterMoved.connect(self._on_splitter_moved) + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + @property + def variables_panel(self) -> VariablesPanel: + """Return the variables panel widget.""" + return self._variables_panel + + @property + def snippet_panel(self) -> SnippetPanel: + """Return the snippet panel widget.""" + return self._snippet_panel + + @property + def active_panel(self) -> str | None: + """Return the key of the currently open panel, or *None*.""" + return self._active_panel + + @property + def panel_open(self) -> bool: + """Return whether any panel is currently visible.""" + return self._active_panel is not None + + def show_request_panels( + self, + variables: dict[str, VariableDetail], + local_overrides: dict[str, LocalOverride] | None = None, + has_environment: bool = True, + *, + method: str = "", + url: str = "", + headers: str | None = None, + body: str | None = None, + auth: dict | None = None, + ) -> None: + """Configure the sidebar for a request tab.""" + self._available_panels = {"variables", "snippet"} + self._default_panel = "snippet" + self._var_btn.setEnabled(True) + self._snippet_btn.show() + self._snippet_btn.setEnabled(True) + + self._variables_panel.load_variables( + variables, + local_overrides=local_overrides, + has_environment=has_environment, + ) + self._snippet_panel.update_request( + method=method, + url=url, + headers=headers, + body=body, + auth=auth, + ) + + if self._active_panel and self._active_panel not in self._available_panels: + self._close_panel() + + def show_folder_panels( + self, + variables: dict[str, VariableDetail], + has_environment: bool = True, + ) -> None: + """Configure the sidebar for a folder tab.""" + self._available_panels = {"variables"} + self._default_panel = "variables" + self._var_btn.setEnabled(True) + self._snippet_btn.hide() + + self._variables_panel.load_variables( + variables, + has_environment=has_environment, + ) + + if self._active_panel == "snippet": + self._close_panel() + + def clear(self) -> None: + """Reset the sidebar to an empty state (no tab open).""" + self._available_panels = set() + self._var_btn.setEnabled(False) + self._snippet_btn.hide() + self._close_panel() + self._variables_panel.clear() + self._snippet_panel.clear() + + def open_panel(self, panel: str) -> None: + """Programmatically open a specific panel by key.""" + if panel in self._available_panels: + self._show_panel(panel) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + def _make_rail_button(self, icon_name: str, tooltip: str) -> QToolButton: + """Create a single rail icon button.""" + btn = QToolButton() + btn.setObjectName("sidebarRailButton") + btn.setIcon(phi(icon_name, size=self._icon_size)) + btn.setToolTip(tooltip) + btn.setCheckable(True) + btn.setFixedSize(self._btn_size, self._btn_size) + btn.setCursor(Qt.CursorShape.PointingHandCursor) + btn.setEnabled(False) + return btn + + def _toggle_panel(self, panel: str) -> None: + """Toggle the given panel open or closed.""" + if self._active_panel == panel: + self._close_panel() + else: + self._show_panel(panel) + + def _show_panel(self, panel: str) -> None: + """Open *panel*, configuring the flyout content.""" + self._active_panel = panel + self._last_panel = panel + self._variables_panel.setVisible(panel == "variables") + self._snippet_panel.setVisible(panel == "snippet") + self._var_btn.setChecked(panel == "variables") + self._snippet_btn.setChecked(panel == "snippet") + self._title_label.setText( + "Variables" if panel == "variables" else "Code snippet", + ) + self._flyout.show() + self._expand_flyout() + + def _close_panel(self) -> None: + """Collapse the flyout, keeping the icon rail visible.""" + self._active_panel = None + self._variables_panel.hide() + self._snippet_panel.hide() + self._var_btn.setChecked(False) + self._snippet_btn.setChecked(False) + self._collapse_flyout() + + def _expand_flyout(self) -> None: + """Expand the flyout in the parent splitter via setSizes.""" + if not self._splitter or self._flyout_idx < 0: + return + sizes = self._splitter.sizes() + if sizes[self._flyout_idx] >= self._panel_hint_width: + return + # Steal space from the content area (index 0). + need = self._panel_hint_width - sizes[self._flyout_idx] + give = min(need, sizes[0]) + sizes[0] -= give + sizes[self._flyout_idx] += give + self._splitter.setSizes(sizes) + + def _collapse_flyout(self) -> None: + """Collapse the flyout in the parent splitter to 0.""" + if not self._splitter or self._flyout_idx < 0: + return + sizes = self._splitter.sizes() + freed = sizes[self._flyout_idx] + if freed <= 0: + return + sizes[0] += freed + sizes[self._flyout_idx] = 0 + self._splitter.setSizes(sizes) + + def _on_splitter_moved(self, _pos: int, _index: int) -> None: + """React to the user dragging the splitter handle.""" + if not self._splitter or self._flyout_idx < 0: + return + flyout_width = self._splitter.sizes()[self._flyout_idx] + + if flyout_width == 0 and self._active_panel: + # User collapsed the flyout by dragging. + self._active_panel = None + self._variables_panel.hide() + self._snippet_panel.hide() + self._var_btn.setChecked(False) + self._snippet_btn.setChecked(False) + + if flyout_width > 0 and not self._active_panel: + # User expanded the flyout by dragging — open a panel. + panel = self._last_panel + if not panel or panel not in self._available_panels: + panel = self._default_panel + if panel: + # Only configure content — don't call _expand_flyout + # again since the user is already controlling the width. + self._active_panel = panel + self._last_panel = panel + self._variables_panel.setVisible(panel == "variables") + self._snippet_panel.setVisible(panel == "snippet") + self._var_btn.setChecked(panel == "variables") + self._snippet_btn.setChecked(panel == "snippet") + self._title_label.setText( + "Variables" if panel == "variables" else "Code snippet", + ) + self._flyout.show() diff --git a/src/ui/sidebar/snippet_panel.py b/src/ui/sidebar/snippet_panel.py new file mode 100644 index 0000000..a30cf50 --- /dev/null +++ b/src/ui/sidebar/snippet_panel.py @@ -0,0 +1,525 @@ +"""Code snippet panel for the right sidebar. + +Inline replacement for the former :class:`CodeSnippetDialog`. Embeds +a language selector, a read-only code editor, a copy-to-clipboard +button, and a settings popup directly inside the sidebar. +""" + +from __future__ import annotations + +import time + +from PySide6.QtCore import QEvent, QSettings, Qt +from PySide6.QtGui import QClipboard, QGuiApplication, QMouseEvent +from PySide6.QtWidgets import ( + QApplication, + QCheckBox, + QComboBox, + QFormLayout, + QFrame, + QHBoxLayout, + QLabel, + QPushButton, + QSpinBox, + QVBoxLayout, + QWidget, +) + +from services.http.snippet_generator import SnippetGenerator, SnippetOptions +from ui.styling.icons import phi +from ui.widgets.code_editor import CodeEditorWidget + +_SHOW_GRACE_SEC = 0.15 + +# QSettings keys +_SETTINGS_PREFIX = "snippet" +_KEY_LANGUAGE = f"{_SETTINGS_PREFIX}/last_language" +_KEY_INDENT_COUNT = f"{_SETTINGS_PREFIX}/indent_count" +_KEY_INDENT_TYPE = f"{_SETTINGS_PREFIX}/indent_type" +_KEY_TRIM_BODY = f"{_SETTINGS_PREFIX}/trim_body" +_KEY_FOLLOW_REDIRECT = f"{_SETTINGS_PREFIX}/follow_redirect" +_KEY_REQUEST_TIMEOUT = f"{_SETTINGS_PREFIX}/request_timeout" +_KEY_INCLUDE_BOILERPLATE = f"{_SETTINGS_PREFIX}/include_boilerplate" +_KEY_ASYNC_AWAIT = f"{_SETTINGS_PREFIX}/async_await" +_KEY_ES6_FEATURES = f"{_SETTINGS_PREFIX}/es6_features" +_KEY_MULTILINE = f"{_SETTINGS_PREFIX}/multiline" +_KEY_LONG_FORM = f"{_SETTINGS_PREFIX}/long_form" +_KEY_LINE_CONTINUATION = f"{_SETTINGS_PREFIX}/line_continuation" +_KEY_QUOTE_TYPE = f"{_SETTINGS_PREFIX}/quote_type" +_KEY_FOLLOW_ORIGINAL_METHOD = f"{_SETTINGS_PREFIX}/follow_original_method" +_KEY_SILENT_MODE = f"{_SETTINGS_PREFIX}/silent_mode" + + +class SnippetSettingsPopup(QFrame): + """Small floating popup for snippet generation settings. + + Positioned below the gear button. All changes apply immediately + (live preview) and persist via QSettings. + """ + + def __init__(self, parent: QWidget | None = None) -> None: + """Initialise the settings popup with controls for all options.""" + super().__init__(parent) + self.setWindowFlags( + Qt.WindowType.Tool + | Qt.WindowType.FramelessWindowHint + | Qt.WindowType.WindowStaysOnTopHint + ) + self.setAttribute(Qt.WidgetAttribute.WA_DeleteOnClose, False) + self.setObjectName("infoPopup") + self._show_time: float = 0.0 + + settings = QSettings() + form = QFormLayout(self) + form.setContentsMargins(12, 10, 12, 10) + form.setSpacing(6) + + # Indent count + self._indent_count = QSpinBox() + self._indent_count.setRange(1, 8) + self._indent_count.setValue(int(str(settings.value(_KEY_INDENT_COUNT, 2)))) + self._indent_count_label = QLabel("Indent count:") + form.addRow(self._indent_count_label, self._indent_count) + + # Indent type + self._indent_type = QComboBox() + self._indent_type.addItems(["Space", "Tab"]) + saved_type = str(settings.value(_KEY_INDENT_TYPE, "space")) + self._indent_type.setCurrentText("Tab" if saved_type == "tab" else "Space") + self._indent_type_label = QLabel("Indent type:") + form.addRow(self._indent_type_label, self._indent_type) + + # Trim body + self._trim_body = QCheckBox() + self._trim_body.setChecked(bool(settings.value(_KEY_TRIM_BODY, False))) + self._trim_body_label = QLabel("Trim body:") + form.addRow(self._trim_body_label, self._trim_body) + + # Follow redirect + self._follow_redirect = QCheckBox() + self._follow_redirect.setChecked( + settings.value(_KEY_FOLLOW_REDIRECT, True) not in (False, "false") + ) + self._redirect_label = QLabel("Follow redirects:") + form.addRow(self._redirect_label, self._follow_redirect) + + # Request timeout + self._request_timeout = QSpinBox() + self._request_timeout.setRange(0, 300) + self._request_timeout.setSuffix(" s") + self._request_timeout.setSpecialValueText("None") + self._request_timeout.setValue(int(str(settings.value(_KEY_REQUEST_TIMEOUT, 0)))) + self._timeout_label = QLabel("Timeout:") + form.addRow(self._timeout_label, self._request_timeout) + + # Include boilerplate + self._include_boilerplate = QCheckBox() + self._include_boilerplate.setChecked( + settings.value(_KEY_INCLUDE_BOILERPLATE, True) not in (False, "false") + ) + self._boilerplate_label = QLabel("Include boilerplate:") + form.addRow(self._boilerplate_label, self._include_boilerplate) + + # Async/await + self._async_await = QCheckBox() + self._async_await.setChecked(bool(settings.value(_KEY_ASYNC_AWAIT, False))) + self._async_label = QLabel("Async/await:") + form.addRow(self._async_label, self._async_await) + + # ES6 features + self._es6_features = QCheckBox() + self._es6_features.setChecked(bool(settings.value(_KEY_ES6_FEATURES, False))) + self._es6_label = QLabel("ES6 features:") + form.addRow(self._es6_label, self._es6_features) + + # --- cURL-specific options --- + + # Multiline + self._multiline = QCheckBox() + self._multiline.setChecked(settings.value(_KEY_MULTILINE, True) not in (False, "false")) + self._multiline_label = QLabel("Multiline:") + form.addRow(self._multiline_label, self._multiline) + + # Long-form options + self._long_form = QCheckBox() + self._long_form.setChecked(settings.value(_KEY_LONG_FORM, True) not in (False, "false")) + self._long_form_label = QLabel("Long form options:") + form.addRow(self._long_form_label, self._long_form) + + # Line continuation character + self._line_continuation = QComboBox() + self._line_continuation.addItems(["\\", "^", "`"]) + saved_cont = str(settings.value(_KEY_LINE_CONTINUATION, "\\")) + idx = self._line_continuation.findText(saved_cont) + if idx >= 0: + self._line_continuation.setCurrentIndex(idx) + self._continuation_label = QLabel("Line continuation:") + form.addRow(self._continuation_label, self._line_continuation) + + # Quote type + self._quote_type = QComboBox() + self._quote_type.addItems(["single", "double"]) + saved_quote = str(settings.value(_KEY_QUOTE_TYPE, "single")) + self._quote_type.setCurrentText(saved_quote) + self._quote_label = QLabel("Quote type:") + form.addRow(self._quote_label, self._quote_type) + + # Follow original method + self._follow_original_method = QCheckBox() + self._follow_original_method.setChecked( + bool(settings.value(_KEY_FOLLOW_ORIGINAL_METHOD, False)) + ) + self._orig_method_label = QLabel("Follow original method:") + form.addRow(self._orig_method_label, self._follow_original_method) + + # Silent mode + self._silent_mode = QCheckBox() + self._silent_mode.setChecked(bool(settings.value(_KEY_SILENT_MODE, False))) + self._silent_label = QLabel("Silent mode:") + form.addRow(self._silent_label, self._silent_mode) + + # Connect for live preview + self._indent_count.valueChanged.connect(self._save) + self._indent_type.currentTextChanged.connect(self._save) + self._trim_body.toggled.connect(self._save) + self._follow_redirect.toggled.connect(self._save) + self._request_timeout.valueChanged.connect(self._save) + self._include_boilerplate.toggled.connect(self._save) + self._async_await.toggled.connect(self._save) + self._es6_features.toggled.connect(self._save) + self._multiline.toggled.connect(self._save) + self._long_form.toggled.connect(self._save) + self._line_continuation.currentTextChanged.connect(self._save) + self._quote_type.currentTextChanged.connect(self._save) + self._follow_original_method.toggled.connect(self._save) + self._silent_mode.toggled.connect(self._save) + + self._on_settings_changed: list[object] = [] + + def set_language_options(self, applicable: tuple[str, ...]) -> None: + """Show/hide controls based on the current language's options.""" + has_indent = "indent_count" in applicable + self._indent_count_label.setVisible(has_indent) + self._indent_count.setVisible(has_indent) + self._indent_type_label.setVisible(has_indent) + self._indent_type.setVisible(has_indent) + + has_trim = "trim_body" in applicable + self._trim_body_label.setVisible(has_trim) + self._trim_body.setVisible(has_trim) + + has_redirect = "follow_redirect" in applicable + self._redirect_label.setVisible(has_redirect) + self._follow_redirect.setVisible(has_redirect) + + has_timeout = "request_timeout" in applicable + self._timeout_label.setVisible(has_timeout) + self._request_timeout.setVisible(has_timeout) + + has_boilerplate = "include_boilerplate" in applicable + self._boilerplate_label.setVisible(has_boilerplate) + self._include_boilerplate.setVisible(has_boilerplate) + + has_async = "async_await" in applicable + self._async_label.setVisible(has_async) + self._async_await.setVisible(has_async) + + has_es6 = "es6_features" in applicable + self._es6_label.setVisible(has_es6) + self._es6_features.setVisible(has_es6) + + has_multiline = "multiline" in applicable + self._multiline_label.setVisible(has_multiline) + self._multiline.setVisible(has_multiline) + + has_long_form = "long_form" in applicable + self._long_form_label.setVisible(has_long_form) + self._long_form.setVisible(has_long_form) + + has_continuation = "line_continuation" in applicable + self._continuation_label.setVisible(has_continuation) + self._line_continuation.setVisible(has_continuation) + + has_quote = "quote_type" in applicable + self._quote_label.setVisible(has_quote) + self._quote_type.setVisible(has_quote) + + has_orig_method = "follow_original_method" in applicable + self._orig_method_label.setVisible(has_orig_method) + self._follow_original_method.setVisible(has_orig_method) + + has_silent = "silent_mode" in applicable + self._silent_label.setVisible(has_silent) + self._silent_mode.setVisible(has_silent) + + def get_options(self) -> SnippetOptions: + """Build a :class:`SnippetOptions` from current control values.""" + return SnippetOptions( + indent_count=self._indent_count.value(), + indent_type="tab" if self._indent_type.currentText() == "Tab" else "space", + trim_body=self._trim_body.isChecked(), + follow_redirect=self._follow_redirect.isChecked(), + request_timeout=self._request_timeout.value(), + include_boilerplate=self._include_boilerplate.isChecked(), + async_await=self._async_await.isChecked(), + es6_features=self._es6_features.isChecked(), + multiline=self._multiline.isChecked(), + long_form=self._long_form.isChecked(), + line_continuation=self._line_continuation.currentText(), + quote_type=self._quote_type.currentText(), + follow_original_method=self._follow_original_method.isChecked(), + silent_mode=self._silent_mode.isChecked(), + ) + + def on_settings_changed(self, callback: object) -> None: + """Register a callback invoked when any setting changes.""" + self._on_settings_changed.append(callback) + + def show_below(self, anchor: QWidget) -> None: + """Position the popup below *anchor* and show it.""" + pos = anchor.mapToGlobal(anchor.rect().bottomLeft()) + self.move(pos) + self._show_time = time.monotonic() + self.show() + self.activateWindow() + self.setFocus() + app = QApplication.instance() + if app is not None: + app.installEventFilter(self) + + def eventFilter(self, obj: QWidget, event: QEvent) -> bool: # type: ignore[override] + """Close on click-outside or parent window move/resize.""" + etype = event.type() + if ( + etype in (QEvent.Type.Move, QEvent.Type.Resize) + and obj is not self + and hasattr(obj, "isWindow") + and obj.isWindow() # type: ignore[union-attr] + ): + self.close() + return False + if etype == QEvent.Type.MouseButtonPress and isinstance(event, QMouseEvent): + if time.monotonic() - self._show_time < _SHOW_GRACE_SEC: + return False + if not self.geometry().contains(event.globalPosition().toPoint()): + self.close() + return False + return super().eventFilter(obj, event) + + def closeEvent(self, event: object) -> None: + """Remove the app-wide event filter when the popup closes.""" + app = QApplication.instance() + if app is not None: + app.removeEventFilter(self) + super().closeEvent(event) # type: ignore[arg-type] + + def _save(self) -> None: + """Persist current values to QSettings and notify listeners.""" + settings = QSettings() + settings.setValue(_KEY_INDENT_COUNT, self._indent_count.value()) + settings.setValue( + _KEY_INDENT_TYPE, + "tab" if self._indent_type.currentText() == "Tab" else "space", + ) + settings.setValue(_KEY_TRIM_BODY, self._trim_body.isChecked()) + settings.setValue(_KEY_FOLLOW_REDIRECT, self._follow_redirect.isChecked()) + settings.setValue(_KEY_REQUEST_TIMEOUT, self._request_timeout.value()) + settings.setValue(_KEY_INCLUDE_BOILERPLATE, self._include_boilerplate.isChecked()) + settings.setValue(_KEY_ASYNC_AWAIT, self._async_await.isChecked()) + settings.setValue(_KEY_ES6_FEATURES, self._es6_features.isChecked()) + settings.setValue(_KEY_MULTILINE, self._multiline.isChecked()) + settings.setValue(_KEY_LONG_FORM, self._long_form.isChecked()) + settings.setValue(_KEY_LINE_CONTINUATION, self._line_continuation.currentText()) + settings.setValue(_KEY_QUOTE_TYPE, self._quote_type.currentText()) + settings.setValue(_KEY_FOLLOW_ORIGINAL_METHOD, self._follow_original_method.isChecked()) + settings.setValue(_KEY_SILENT_MODE, self._silent_mode.isChecked()) + + for cb in self._on_settings_changed: + if callable(cb): + cb() + + +class SnippetPanel(QWidget): + """Inline code snippet generator panel. + + Displays the current request as a code snippet in the user's + chosen language. Call :meth:`update_request` whenever the + request editor state changes. + """ + + def __init__(self, parent: QWidget | None = None) -> None: + """Initialise the snippet panel with default empty state.""" + super().__init__(parent) + + self._method = "" + self._url = "" + self._headers: str | None = None + self._body: str | None = None + self._auth: dict | None = None + + layout = QVBoxLayout(self) + layout.setContentsMargins(8, 4, 8, 8) + layout.setSpacing(6) + + # Language selector row + selector_row = QHBoxLayout() + selector_row.setContentsMargins(0, 0, 0, 0) + selector_row.setSpacing(6) + + self._lang_combo = QComboBox() + self._lang_combo.addItems(SnippetGenerator.available_languages()) + self._lang_combo.setFixedHeight(28) + + # Restore last-selected language + settings = QSettings() + saved_lang = str(settings.value(_KEY_LANGUAGE, "cURL")) + idx = self._lang_combo.findText(saved_lang) + if idx >= 0: + self._lang_combo.setCurrentIndex(idx) + + self._lang_combo.currentTextChanged.connect(self._on_language_changed) + selector_row.addWidget(self._lang_combo, 1) + + # Settings gear button + self._settings_btn = QPushButton() + self._settings_btn.setIcon(phi("gear")) + self._settings_btn.setObjectName("iconButton") + self._settings_btn.setFixedSize(28, 28) + self._settings_btn.setCursor(Qt.CursorShape.PointingHandCursor) + self._settings_btn.setToolTip("Snippet settings") + self._settings_btn.clicked.connect(self._toggle_settings) + selector_row.addWidget(self._settings_btn) + + self._copy_btn = QPushButton() + self._copy_btn.setIcon(phi("clipboard")) + self._copy_btn.setObjectName("iconButton") + self._copy_btn.setFixedSize(28, 28) + self._copy_btn.setCursor(Qt.CursorShape.PointingHandCursor) + self._copy_btn.setToolTip("Copy to clipboard") + self._copy_btn.clicked.connect(self._copy_to_clipboard) + selector_row.addWidget(self._copy_btn) + + layout.addLayout(selector_row) + + # Code editor (read-only) + self._code_edit = CodeEditorWidget(read_only=True) + self._code_edit.setMinimumHeight(120) + layout.addWidget(self._code_edit, 1) + + # Status label + self._status_label = QLabel("") + self._status_label.setObjectName("mutedLabel") + layout.addWidget(self._status_label) + + # Settings popup (lazy) + self._settings_popup: SnippetSettingsPopup | None = None + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + def update_request( + self, + *, + method: str, + url: str, + headers: str | None = None, + body: str | None = None, + auth: dict | None = None, + ) -> None: + """Update the stored request data and regenerate the snippet.""" + self._method = method + self._url = url + self._headers = headers + self._body = body + self._auth = auth + self._refresh() + + def clear(self) -> None: + """Reset the panel to an empty state.""" + self._method = "" + self._url = "" + self._headers = None + self._body = None + self._auth = None + self._code_edit.set_text("") + self._status_label.setText("") + + # ------------------------------------------------------------------ + # Internal + # ------------------------------------------------------------------ + def _current_options(self) -> SnippetOptions | None: + """Build options from the settings popup or QSettings.""" + if self._settings_popup is not None: + return self._settings_popup.get_options() + settings = QSettings() + return SnippetOptions( + indent_count=int(str(settings.value(_KEY_INDENT_COUNT, 2))), + indent_type=str(settings.value(_KEY_INDENT_TYPE, "space")), + trim_body=settings.value(_KEY_TRIM_BODY, False) not in (False, "false"), + follow_redirect=settings.value(_KEY_FOLLOW_REDIRECT, True) not in (False, "false"), + request_timeout=int(str(settings.value(_KEY_REQUEST_TIMEOUT, 0))), + include_boilerplate=settings.value(_KEY_INCLUDE_BOILERPLATE, True) + not in (False, "false"), + async_await=settings.value(_KEY_ASYNC_AWAIT, False) not in (False, "false", 0, "0"), + es6_features=settings.value(_KEY_ES6_FEATURES, False) not in (False, "false", 0, "0"), + multiline=settings.value(_KEY_MULTILINE, True) not in (False, "false"), + long_form=settings.value(_KEY_LONG_FORM, True) not in (False, "false"), + line_continuation=str(settings.value(_KEY_LINE_CONTINUATION, "\\")), + quote_type=str(settings.value(_KEY_QUOTE_TYPE, "single")), + follow_original_method=settings.value(_KEY_FOLLOW_ORIGINAL_METHOD, False) + not in (False, "false", 0, "0"), + silent_mode=settings.value(_KEY_SILENT_MODE, False) not in (False, "false", 0, "0"), + ) + + def _refresh(self) -> None: + """Regenerate the snippet for the selected language.""" + if not self._url: + self._code_edit.set_text("") + return + lang = self._lang_combo.currentText() + snippet = SnippetGenerator.generate( + lang, + method=self._method, + url=self._url, + headers=self._headers, + body=self._body, + auth=self._auth, + options=self._current_options(), + ) + info = SnippetGenerator.get_language_info(lang) + lexer = info.lexer if info else "text" + self._code_edit.set_language(lexer) + self._code_edit.set_text(snippet) + self._status_label.setText("") + + def _on_language_changed(self, lang: str) -> None: + """Handle language combo change — persist and refresh.""" + QSettings().setValue(_KEY_LANGUAGE, lang) + if self._settings_popup is not None: + info = SnippetGenerator.get_language_info(lang) + if info: + self._settings_popup.set_language_options(info.applicable_options) + self._refresh() + + def _toggle_settings(self) -> None: + """Show or hide the snippet settings popup.""" + if self._settings_popup is not None and self._settings_popup.isVisible(): + self._settings_popup.hide() + return + + if self._settings_popup is None: + self._settings_popup = SnippetSettingsPopup(self) + self._settings_popup.on_settings_changed(self._refresh) + + lang = self._lang_combo.currentText() + info = SnippetGenerator.get_language_info(lang) + if info: + self._settings_popup.set_language_options(info.applicable_options) + self._settings_popup.show_below(self._settings_btn) + + def _copy_to_clipboard(self) -> None: + """Copy the current snippet text to the system clipboard.""" + clipboard: QClipboard | None = QGuiApplication.clipboard() + if clipboard is not None: + clipboard.setText(self._code_edit.toPlainText()) + self._status_label.setText("Copied!") diff --git a/src/ui/sidebar/variables_panel.py b/src/ui/sidebar/variables_panel.py new file mode 100644 index 0000000..d6f6dec --- /dev/null +++ b/src/ui/sidebar/variables_panel.py @@ -0,0 +1,256 @@ +"""Variables panel for the right sidebar. + +Displays resolved variables grouped by source (Environment, Collection, +Local Overrides) in a read-only list. The panel accepts a +:class:`VariableDetail` map and optional :class:`LocalOverride` map, +groups entries by their ``source`` field, and renders them as +key-value rows under collapsible source headings. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from PySide6.QtCore import Qt +from PySide6.QtGui import QPainter +from PySide6.QtWidgets import QHBoxLayout, QLabel, QScrollArea, QSizePolicy, QVBoxLayout, QWidget + +if TYPE_CHECKING: + from services.environment_service import LocalOverride, VariableDetail + + +class _ElidedLabel(QLabel): + """QLabel that elides text with an ellipsis when space is tight.""" + + def __init__(self, text: str = "", parent: QWidget | None = None) -> None: + super().__init__(text, parent) + self._full_text = text + + def setText(self, text: str) -> None: + """Store full text and trigger repaint.""" + self._full_text = text + super().setText(text) + + def paintEvent(self, event: object) -> None: + """Draw text with right-elision when it overflows.""" + painter = QPainter(self) + fm = self.fontMetrics() + elided = fm.elidedText( + self._full_text, + Qt.TextElideMode.ElideRight, + self.width(), + ) + painter.setPen(self.palette().color(self.foregroundRole())) + painter.drawText(self.rect(), int(Qt.AlignmentFlag.AlignVCenter), elided) + painter.end() + + +class VariablesPanel(QWidget): + """Read-only panel showing variables grouped by source. + + Sections are rendered for **Environment**, **Collection**, and + **Local Overrides** (only when entries exist for each). + """ + + def __init__(self, parent: QWidget | None = None) -> None: + """Initialise the variables panel with an empty state.""" + super().__init__(parent) + + layout = QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(0) + + # Scrollable content area + self._scroll = QScrollArea() + self._scroll.setWidgetResizable(True) + self._scroll.setFrameShape(QScrollArea.Shape.NoFrame) + layout.addWidget(self._scroll) + + self._content = QWidget() + self._content_layout = QVBoxLayout(self._content) + self._content_layout.setContentsMargins(12, 8, 12, 8) + self._content_layout.setSpacing(0) + self._content_layout.addStretch() + self._scroll.setWidget(self._content) + + self._show_empty_state() + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + def load_variables( + self, + variables: dict[str, VariableDetail], + local_overrides: dict[str, LocalOverride] | None = None, + has_environment: bool = True, + ) -> None: + """Populate the panel with variable data grouped by source. + + *variables* maps variable names to their resolved metadata. + *local_overrides* are per-request overrides (shown separately). + *has_environment* controls whether the environment header is + shown or a 'No environment selected' hint appears. + """ + self._clear_content() + + # Group variables by source + env_vars: list[tuple[str, str]] = [] + coll_vars: list[tuple[str, str]] = [] + for name, detail in sorted(variables.items()): + if detail.get("is_local"): + continue # handled via local_overrides + if detail["source"] == "environment": + env_vars.append((name, detail["value"])) + elif detail["source"] == "collection": + coll_vars.append((name, detail["value"])) + + local_vars: list[tuple[str, str]] = [] + if local_overrides: + for name, override in sorted(local_overrides.items()): + local_vars.append((name, override["value"])) + + has_any = bool(env_vars or coll_vars or local_vars) + + # 1. Environment section + if not has_environment: + hint = QLabel("No environment selected. Select environment") + hint.setObjectName("emptyStateLabel") + hint.setWordWrap(True) + self._content_layout.addWidget(hint) + self._add_separator() + elif env_vars: + self._add_section("Environment", "environment", env_vars) + else: + self._add_section_header("Environment", "environment") + empty = QLabel("No variables") + empty.setObjectName("mutedLabel") + self._content_layout.addWidget(empty) + self._add_separator() + + # 2. Collection section + if coll_vars: + self._add_section("Requests collection", "collection", coll_vars) + elif has_any or not has_environment: + self._add_section_header("Requests collection", "collection") + empty = QLabel("No variables") + empty.setObjectName("mutedLabel") + self._content_layout.addWidget(empty) + self._add_separator() + + # 3. Local overrides section (only when entries exist) + if local_vars: + self._add_section("Local overrides", "local", local_vars) + + if not has_any and has_environment: + self._show_empty_state() + return + + self._content_layout.addStretch() + + def clear(self) -> None: + """Reset the panel to its empty state.""" + self._clear_content() + self._show_empty_state() + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + def _clear_content(self) -> None: + """Remove all widgets and sub-layouts from the content layout.""" + while self._content_layout.count(): + item = self._content_layout.takeAt(0) + if item is None: + continue + widget = item.widget() + if widget is not None: + widget.setParent(None) + continue + sub = item.layout() + if sub is not None: + self._clear_layout(sub) + + @staticmethod + def _clear_layout(layout: object) -> None: + """Recursively delete all items in a sub-layout.""" + from PySide6.QtWidgets import QLayout + + if not isinstance(layout, QLayout): + return + while layout.count(): + child = layout.takeAt(0) + if child is None: + continue + w = child.widget() + if w is not None: + w.setParent(None) + else: + VariablesPanel._clear_layout(child.layout()) + + def _show_empty_state(self) -> None: + """Display an empty-state label when no variables are available.""" + label = QLabel("No variables available") + label.setObjectName("emptyStateLabel") + label.setAlignment(Qt.AlignmentFlag.AlignCenter) + self._content_layout.addWidget(label) + self._content_layout.addStretch() + + def _add_section( + self, + title: str, + source: str, + variables: list[tuple[str, str]], + ) -> None: + """Add a source section with header and variable rows.""" + self._add_section_header(title, source) + for name, value in variables: + self._add_variable_row(name, value) + self._add_separator() + + def _add_section_header(self, title: str, source: str) -> None: + """Add a section header with a colored source dot and title.""" + row = QHBoxLayout() + row.setContentsMargins(0, 8, 0, 4) + row.setSpacing(6) + + dot = QLabel("\u2022") + dot.setObjectName("sidebarSourceDot") + dot.setProperty("varSource", source) + dot.setFixedWidth(12) + dot.setAlignment(Qt.AlignmentFlag.AlignCenter) + row.addWidget(dot) + + label = QLabel(title) + label.setObjectName("sidebarSectionLabel") + row.addWidget(label) + row.addStretch() + + self._content_layout.addLayout(row) + + def _add_variable_row(self, name: str, value: str) -> None: + """Add a single key-value variable row.""" + row = QHBoxLayout() + row.setContentsMargins(18, 2, 0, 2) + row.setSpacing(12) + + key_label = _ElidedLabel(name) + key_label.setObjectName("variableKeyLabel") + key_label.setFixedWidth(120) + key_label.setSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed) + row.addWidget(key_label) + + val_label = _ElidedLabel(value) + val_label.setObjectName("variableValueLabel") + val_label.setToolTip(value) + val_label.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed) + val_label.setMinimumWidth(0) + row.addWidget(val_label, 1) + + self._content_layout.addLayout(row) + + def _add_separator(self) -> None: + """Add a thin horizontal separator line.""" + sep = QLabel() + sep.setObjectName("sidebarSeparator") + sep.setFixedHeight(1) + self._content_layout.addWidget(sep) + self._content_layout.addWidget(sep) diff --git a/src/ui/styling/global_qss.py b/src/ui/styling/global_qss.py index e5e215b..3b69de4 100644 --- a/src/ui/styling/global_qss.py +++ b/src/ui/styling/global_qss.py @@ -561,4 +561,113 @@ def build_global_qss(p: ThemePalette) -> str: padding: 6px 8px; font-size: 11px; }} + + /* ---- New-item dialog (icon grid) ------------------------------ */ + QDialog[objectName="newItemPopup"] {{ + background: {p["bg"]}; + }} + QLabel[objectName="newItemTitle"] {{ + font-size: 14px; + font-weight: 600; + color: {p["text"]}; + }} + QPushButton[objectName="newItemTile"] {{ + background: {p["bg_alt"]}; + border: 1px solid {p["border"]}; + border-radius: 8px; + }} + QPushButton[objectName="newItemTile"]:hover {{ + border-color: {p["accent"]}; + background: {"rgba(79,193,255,0.08)" if p is DARK_PALETTE else "rgba(52,152,219,0.06)"}; + }} + QLabel[objectName="newItemTileLabel"] {{ + font-size: 12px; + font-weight: 500; + color: {p["text"]}; + }} + QLabel[objectName="newItemDescription"] {{ + font-size: 11px; + color: {p["text_muted"]}; + padding: 8px 4px 0px 4px; + }} + + /* ---- Save-request dialog ------------------------------------ */ + QTreeWidget[objectName="collectionTree"] {{ + border: 1px solid {p["border"]}; + background: {p["input_bg"]}; + border-radius: 4px; + outline: none; + }} + QTreeWidget[objectName="collectionTree"]::item {{ + padding: 6px 8px; + border: none; + }} + QTreeWidget[objectName="collectionTree"]::item:hover {{ + background: {p["hover_tree_bg"]}; + }} + QTreeWidget[objectName="collectionTree"]::item:selected {{ + background: {p["selected_bg"]}; + color: {p["text"]}; + }} + + /* ---- Right sidebar ------------------------------------------ */ + QWidget[objectName="sidebarPanelArea"] {{ + background: {p["bg"]}; + border-right: 1px solid {p["border"]}; + }} + QWidget[objectName="sidebarRail"] {{ + background: {p["bg"]}; + border-left: 1px solid {p["border"]}; + }} + QToolButton[objectName="sidebarRailButton"] {{ + background: transparent; + border: none; + border-radius: 4px; + margin: 2px 3px; + color: {p["text_muted"]}; + }} + QToolButton[objectName="sidebarRailButton"]:hover {{ + background: {"rgba(255,255,255,0.06)" if p is DARK_PALETTE else "rgba(0,0,0,0.05)"}; + }} + QToolButton[objectName="sidebarRailButton"]:checked {{ + background: {"rgba(255,255,255,0.10)" if p is DARK_PALETTE else "rgba(0,0,0,0.08)"}; + color: {p["text"]}; + }} + QToolButton[objectName="sidebarRailButton"]:disabled {{ + color: {p["text_muted"]}; + opacity: 0.4; + }} + QLabel[objectName="sidebarTitleLabel"] {{ + font-weight: bold; + font-size: 13px; + color: {p["text"]}; + }} + QLabel[objectName="variableKeyLabel"] {{ + font-family: monospace; + font-size: 12px; + color: {p["text"]}; + }} + QLabel[objectName="variableValueLabel"] {{ + font-family: monospace; + font-size: 12px; + color: {p["text_muted"]}; + }} + QLabel[objectName="sidebarSourceDot"] {{ + font-size: 16px; + font-weight: bold; + }} + QLabel[objectName="sidebarSourceDot"][varSource="environment"] {{ + color: {p["accent"]}; + }} + QLabel[objectName="sidebarSourceDot"][varSource="collection"] {{ + color: {p["success"]}; + }} + QLabel[objectName="sidebarSourceDot"][varSource="local"] {{ + color: {p["warning"]}; + }} + QLabel[objectName="sidebarSeparator"] {{ + background: {p["border"]}; + margin-top: 4px; + margin-bottom: 4px; + }} """ diff --git a/tests/ui/collections/test_collection_header.py b/tests/ui/collections/test_collection_header.py index b5e404c..e0b182e 100644 --- a/tests/ui/collections/test_collection_header.py +++ b/tests/ui/collections/test_collection_header.py @@ -18,15 +18,12 @@ def test_construction(self, qapp: QApplication, qtbot) -> None: assert header.height() == 70 def test_new_collection_signal(self, qapp: QApplication, qtbot) -> None: - """Clicking the + menu emits ``new_collection_requested(None)``.""" + """Clicking the collection tile emits ``new_collection_requested(None)``.""" header = CollectionHeader() qtbot.addWidget(header) with qtbot.waitSignal(header.new_collection_requested, timeout=1000) as blocker: - # Directly trigger the action instead of clicking through the menu - actions = header._plus_menu.actions() - assert len(actions) >= 1, "Plus menu should have at least one action" - actions[0].trigger() + header._popup.new_collection_clicked.emit() assert blocker.args == [None] @@ -40,40 +37,23 @@ def test_search_changed_signal(self, qapp: QApplication, qtbot) -> None: assert blocker.args == ["hello"] - def test_new_request_disabled_by_default(self, qapp: QApplication, qtbot) -> None: - """The 'New request' action is disabled when no collection is selected.""" + def test_new_request_emits_draft_signal(self, qapp: QApplication, qtbot) -> None: + """Clicking the request tile emits ``new_request_requested(None)`` (draft).""" header = CollectionHeader() qtbot.addWidget(header) - assert not header._new_req_act.isEnabled() - - def test_new_request_enabled_after_selection(self, qapp: QApplication, qtbot) -> None: - """Setting a selected collection ID enables the 'New request' action.""" - header = CollectionHeader() - qtbot.addWidget(header) - - header.set_selected_collection_id(42) - assert header._new_req_act.isEnabled() - - def test_new_request_emits_signal(self, qapp: QApplication, qtbot) -> None: - """Triggering 'New request' emits ``new_request_requested`` with the ID.""" - header = CollectionHeader() - qtbot.addWidget(header) - - header.set_selected_collection_id(42) - with qtbot.waitSignal(header.new_request_requested, timeout=1000) as blocker: - header._new_req_act.trigger() + header._popup.new_request_clicked.emit() - assert blocker.args == [42] + assert blocker.args == [None] - def test_new_request_disabled_on_none_selection(self, qapp: QApplication, qtbot) -> None: - """Clearing the selection disables the 'New request' action.""" + def test_set_selected_collection_id(self, qapp: QApplication, qtbot) -> None: + """``set_selected_collection_id`` stores the collection ID.""" header = CollectionHeader() qtbot.addWidget(header) header.set_selected_collection_id(42) - assert header._new_req_act.isEnabled() + assert header._selected_collection_id == 42 header.set_selected_collection_id(None) - assert not header._new_req_act.isEnabled() + assert header._selected_collection_id is None diff --git a/tests/ui/collections/test_new_item_popup.py b/tests/ui/collections/test_new_item_popup.py new file mode 100644 index 0000000..4473d5d --- /dev/null +++ b/tests/ui/collections/test_new_item_popup.py @@ -0,0 +1,100 @@ +"""Tests for the NewItemPopup dialog.""" + +from __future__ import annotations + +from PySide6.QtCore import Qt +from PySide6.QtWidgets import QApplication, QDialog, QPushButton + +from ui.collections.new_item_popup import NewItemPopup, _Tile + + +class TestTile: + """Tests for the internal _Tile button widget.""" + + def test_construction(self, qapp: QApplication, qtbot) -> None: + """_Tile can be instantiated with an icon name and label.""" + tile = _Tile("globe", "HTTP Request") + qtbot.addWidget(tile) + assert tile.objectName() == "newItemTile" + + def test_fixed_size(self, qapp: QApplication, qtbot) -> None: + """_Tile has fixed 140x110 size.""" + tile = _Tile("globe", "HTTP Request") + qtbot.addWidget(tile) + assert tile.width() == 140 + assert tile.height() == 110 + + def test_cursor_is_hand(self, qapp: QApplication, qtbot) -> None: + """_Tile shows a pointing hand cursor.""" + tile = _Tile("globe", "HTTP Request") + qtbot.addWidget(tile) + assert tile.cursor().shape() == Qt.CursorShape.PointingHandCursor + + +class TestNewItemPopup: + """Tests for the NewItemPopup dialog window.""" + + def test_construction(self, qapp: QApplication, qtbot) -> None: + """NewItemPopup can be instantiated without errors.""" + popup = NewItemPopup() + qtbot.addWidget(popup) + assert popup.objectName() == "newItemPopup" + + def test_is_qdialog(self, qapp: QApplication, qtbot) -> None: + """NewItemPopup is a QDialog.""" + popup = NewItemPopup() + qtbot.addWidget(popup) + assert isinstance(popup, QDialog) + + def test_window_title(self, qapp: QApplication, qtbot) -> None: + """Dialog has the 'Create New' title.""" + popup = NewItemPopup() + qtbot.addWidget(popup) + assert popup.windowTitle() == "Create New" + + def test_has_description_label(self, qapp: QApplication, qtbot) -> None: + """Dialog contains a description label with default text.""" + popup = NewItemPopup() + qtbot.addWidget(popup) + assert popup._description is not None + assert "HTTP request" in popup._description.text() + + def test_new_request_signal(self, qapp: QApplication, qtbot) -> None: + """Clicking the HTTP tile emits ``new_request_clicked``.""" + popup = NewItemPopup() + qtbot.addWidget(popup) + + http_tile = popup.findChildren(QPushButton)[0] + with qtbot.waitSignal(popup.new_request_clicked, timeout=1000): + http_tile.click() + + def test_new_collection_signal(self, qapp: QApplication, qtbot) -> None: + """Clicking the Collection tile emits ``new_collection_clicked``.""" + popup = NewItemPopup() + qtbot.addWidget(popup) + + tiles = popup.findChildren(QPushButton) + collection_tile = tiles[1] + with qtbot.waitSignal(popup.new_collection_clicked, timeout=1000): + collection_tile.click() + + def test_accept_on_http_click(self, qapp: QApplication, qtbot) -> None: + """Dialog accepts after clicking the HTTP tile.""" + popup = NewItemPopup() + qtbot.addWidget(popup) + + # _on_http_clicked should emit + accept + results: list[int] = [] + popup.finished.connect(lambda r: results.append(r)) + popup._on_http_clicked() + assert results == [int(QDialog.DialogCode.Accepted)] + + def test_accept_on_collection_click(self, qapp: QApplication, qtbot) -> None: + """Dialog accepts after clicking the Collection tile.""" + popup = NewItemPopup() + qtbot.addWidget(popup) + + results: list[int] = [] + popup.finished.connect(lambda r: results.append(r)) + popup._on_collection_clicked() + assert results == [int(QDialog.DialogCode.Accepted)] diff --git a/tests/ui/dialogs/test_save_request_dialog.py b/tests/ui/dialogs/test_save_request_dialog.py new file mode 100644 index 0000000..d7e5fbc --- /dev/null +++ b/tests/ui/dialogs/test_save_request_dialog.py @@ -0,0 +1,188 @@ +"""Tests for the SaveRequestDialog widget.""" + +from __future__ import annotations + +from PySide6.QtCore import Qt +from PySide6.QtWidgets import QApplication, QTreeWidgetItem + +from services.collection_service import CollectionService +from ui.dialogs.save_request_dialog import SaveRequestDialog, _tree_item_iterator + + +class TestSaveRequestDialogConstruction: + """Tests for initial dialog state.""" + + def test_construction(self, qapp: QApplication, qtbot) -> None: + """SaveRequestDialog can be instantiated without errors.""" + dialog = SaveRequestDialog() + qtbot.addWidget(dialog) + assert dialog.windowTitle() == "Save Request" + + def test_default_name(self, qapp: QApplication, qtbot) -> None: + """Default request name is 'Untitled Request'.""" + dialog = SaveRequestDialog() + qtbot.addWidget(dialog) + assert dialog.request_name() == "Untitled Request" + + def test_custom_default_name(self, qapp: QApplication, qtbot) -> None: + """Custom default name is shown in the name input.""" + dialog = SaveRequestDialog(default_name="http://example.com") + qtbot.addWidget(dialog) + assert dialog.request_name() == "http://example.com" + + def test_minimum_size(self, qapp: QApplication, qtbot) -> None: + """Dialog has a reasonable minimum size.""" + dialog = SaveRequestDialog() + qtbot.addWidget(dialog) + assert dialog.minimumWidth() >= 420 + assert dialog.minimumHeight() >= 460 + + def test_save_button_disabled_initially(self, qapp: QApplication, qtbot) -> None: + """Save button is disabled when no collection is selected.""" + dialog = SaveRequestDialog() + qtbot.addWidget(dialog) + assert not dialog._save_btn.isEnabled() + + def test_no_collection_selected_initially(self, qapp: QApplication, qtbot) -> None: + """No collection is selected on construction.""" + dialog = SaveRequestDialog() + qtbot.addWidget(dialog) + assert dialog.selected_collection_id() is None + + def test_tree_header_hidden(self, qapp: QApplication, qtbot) -> None: + """The tree widget has its header hidden.""" + dialog = SaveRequestDialog() + qtbot.addWidget(dialog) + assert dialog._tree.isHeaderHidden() + + +class TestSaveRequestDialogCollectionTree: + """Tests for the collection tree and selection.""" + + def test_collections_populated(self, qapp: QApplication, qtbot) -> None: + """Collections from the database appear in the tree.""" + CollectionService.create_collection("Alpha") + CollectionService.create_collection("Beta") + + dialog = SaveRequestDialog() + qtbot.addWidget(dialog) + + items = _tree_item_iterator(dialog._tree) + names = [it.text(0) for it in items] + assert "Alpha" in names + assert "Beta" in names + + def test_nested_collections_appear_as_children(self, qapp: QApplication, qtbot) -> None: + """Nested collections appear as child items in the tree.""" + parent = CollectionService.create_collection("Parent") + CollectionService.create_collection("Child", parent_id=parent.id) + + dialog = SaveRequestDialog() + qtbot.addWidget(dialog) + + # Find the parent item + parent_item: QTreeWidgetItem | None = None + for i in range(dialog._tree.topLevelItemCount()): + top = dialog._tree.topLevelItem(i) + if top is not None and top.data(0, Qt.ItemDataRole.UserRole) == parent.id: + parent_item = top + break + + assert parent_item is not None + assert parent_item.childCount() >= 1 + + def test_click_selects_collection(self, qapp: QApplication, qtbot) -> None: + """Clicking a tree item selects that collection.""" + coll = CollectionService.create_collection("ClickMe") + + dialog = SaveRequestDialog() + qtbot.addWidget(dialog) + + # Find the item for our collection + for item in _tree_item_iterator(dialog._tree): + if item.data(0, Qt.ItemDataRole.UserRole) == coll.id: + dialog._on_item_clicked(item, 0) + break + + assert dialog.selected_collection_id() == coll.id + assert dialog._save_btn.isEnabled() + + def test_search_filters_tree(self, qapp: QApplication, qtbot) -> None: + """Typing in the search field hides non-matching tree items.""" + CollectionService.create_collection("SearchTarget") + CollectionService.create_collection("OtherCollection") + + dialog = SaveRequestDialog() + qtbot.addWidget(dialog) + + dialog._search_input.setText("SearchTarget") + + visible = [it for it in _tree_item_iterator(dialog._tree) if not it.isHidden()] + assert len(visible) >= 1 + assert all("SearchTarget" in it.text(0) for it in visible) + + def test_search_clear_restores_tree(self, qapp: QApplication, qtbot) -> None: + """Clearing the search makes all items visible again.""" + CollectionService.create_collection("A") + CollectionService.create_collection("B") + + dialog = SaveRequestDialog() + qtbot.addWidget(dialog) + + dialog._search_input.setText("A") + dialog._search_input.clear() + + all_items = _tree_item_iterator(dialog._tree) + hidden = [it for it in all_items if it.isHidden()] + assert len(hidden) == 0 + + +class TestSaveRequestDialogNewCollection: + """Tests for the 'New Collection' inline creation.""" + + def test_new_collection_adds_to_tree(self, qapp: QApplication, qtbot) -> None: + """Clicking 'New Collection' creates one and adds it to the tree.""" + dialog = SaveRequestDialog() + qtbot.addWidget(dialog) + + count_before = len(_tree_item_iterator(dialog._tree)) + dialog._on_new_collection() + count_after = len(_tree_item_iterator(dialog._tree)) + + assert count_after == count_before + 1 + + def test_new_collection_auto_selects(self, qapp: QApplication, qtbot) -> None: + """The newly created collection is automatically selected.""" + dialog = SaveRequestDialog() + qtbot.addWidget(dialog) + + dialog._on_new_collection() + + assert dialog.selected_collection_id() is not None + assert dialog._save_btn.isEnabled() + + +class TestSaveRequestDialogRequestName: + """Tests for the request name input.""" + + def test_empty_name_returns_default(self, qapp: QApplication, qtbot) -> None: + """An empty name field falls back to 'Untitled Request'.""" + dialog = SaveRequestDialog() + qtbot.addWidget(dialog) + dialog._name_input.clear() + assert dialog.request_name() == "Untitled Request" + + def test_whitespace_name_returns_default(self, qapp: QApplication, qtbot) -> None: + """A whitespace-only name falls back to 'Untitled Request'.""" + dialog = SaveRequestDialog() + qtbot.addWidget(dialog) + dialog._name_input.setText(" ") + assert dialog.request_name() == "Untitled Request" + + def test_custom_name_returned(self, qapp: QApplication, qtbot) -> None: + """A user-entered name is returned as-is (stripped).""" + dialog = SaveRequestDialog() + qtbot.addWidget(dialog) + dialog._name_input.setText(" My Request ") + assert dialog.request_name() == "My Request" + assert dialog.request_name() == "My Request" diff --git a/tests/ui/request/navigation/test_request_tab_bar.py b/tests/ui/request/navigation/test_request_tab_bar.py index 5a99255..28c2989 100644 --- a/tests/ui/request/navigation/test_request_tab_bar.py +++ b/tests/ui/request/navigation/test_request_tab_bar.py @@ -2,6 +2,8 @@ from __future__ import annotations +from PySide6.QtCore import QPoint, Qt +from PySide6.QtGui import QMouseEvent from PySide6.QtWidgets import QApplication, QTabBar from ui.request.navigation.request_tab_bar import RequestTabBar @@ -145,6 +147,32 @@ def test_close_button_not_hidden(self, qapp: QApplication, qtbot) -> None: if close_btn is not None: assert not close_btn.isHidden() + def test_middle_click_closes_tab(self, qapp: QApplication, qtbot) -> None: + """Middle-clicking a tab emits tab_close_requested.""" + bar = RequestTabBar() + qtbot.addWidget(bar) + bar.add_request_tab("GET", "First") + bar.add_request_tab("POST", "Second") + bar.show() + bar.resize(400, 30) + qapp.processEvents() + + tab_rect = bar.tabRect(1) + pos = tab_rect.center() + + with qtbot.waitSignal(bar.tab_close_requested) as blocker: + event = QMouseEvent( + QMouseEvent.Type.MouseButtonPress, + QPoint(pos.x(), pos.y()), + bar.mapToGlobal(QPoint(pos.x(), pos.y())), + Qt.MouseButton.MiddleButton, + Qt.MouseButton.MiddleButton, + Qt.KeyboardModifier.NoModifier, + ) + bar.mousePressEvent(event) + + assert blocker.args == [1] + class TestMainWindowMultiTab: """Tests for multi-tab behaviour in MainWindow.""" diff --git a/tests/ui/request/navigation/test_tab_manager.py b/tests/ui/request/navigation/test_tab_manager.py index 82a2971..a8ce771 100644 --- a/tests/ui/request/navigation/test_tab_manager.py +++ b/tests/ui/request/navigation/test_tab_manager.py @@ -28,6 +28,7 @@ def test_construction_defaults(self, qapp: QApplication, qtbot) -> None: assert not ctx.is_dirty assert not ctx.is_sending assert not ctx.is_preview + assert ctx.draft_name is None def test_local_overrides_defaults_empty(self, qapp: QApplication, qtbot) -> None: """TabContext initialises with empty local_overrides dict.""" diff --git a/tests/ui/request/test_folder_editor.py b/tests/ui/request/test_folder_editor.py index 32b59ed..51ec0fa 100644 --- a/tests/ui/request/test_folder_editor.py +++ b/tests/ui/request/test_folder_editor.py @@ -137,13 +137,30 @@ def test_load_apikey_auth(self, qapp: QApplication, qtbot) -> None: assert editor._apikey_add_to_combo.currentText() == "Header" def test_load_no_auth(self, qapp: QApplication, qtbot) -> None: - """Loading with no auth defaults to No Auth.""" + """Loading with no auth defaults to Inherit auth from parent.""" editor = FolderEditorWidget() qtbot.addWidget(editor) editor.load_collection({"name": "Coll"}, collection_id=1) + assert editor._auth_type_combo.currentText() == "Inherit auth from parent" + + def test_load_explicit_noauth(self, qapp: QApplication, qtbot) -> None: + """Loading with explicit noauth selects No Auth.""" + editor = FolderEditorWidget() + qtbot.addWidget(editor) + + editor.load_collection({"name": "Coll", "auth": {"type": "noauth"}}, collection_id=1) assert editor._auth_type_combo.currentText() == "No Auth" + def test_get_inherit_auth_returns_none(self, qapp: QApplication, qtbot) -> None: + """get_collection_data returns auth=None for inherit.""" + editor = FolderEditorWidget() + qtbot.addWidget(editor) + + editor.load_collection({"name": "Coll"}, collection_id=1) + result = editor.get_collection_data() + assert result["auth"] is None + def test_load_scripts(self, qapp: QApplication, qtbot) -> None: """Loading populates the pre-request and test script fields.""" editor = FolderEditorWidget() @@ -248,7 +265,7 @@ def test_clear_resets_fields(self, qapp: QApplication, qtbot) -> None: assert editor._description_edit.toPlainText() == "" assert editor._pre_request_edit.toPlainText() == "" assert editor._test_script_edit.toPlainText() == "" - assert editor._auth_type_combo.currentText() == "No Auth" + assert editor._auth_type_combo.currentText() == "Inherit auth from parent" class TestFolderEditorSignal: diff --git a/tests/ui/request/test_request_editor_auth.py b/tests/ui/request/test_request_editor_auth.py index d623970..cc865e8 100644 --- a/tests/ui/request/test_request_editor_auth.py +++ b/tests/ui/request/test_request_editor_auth.py @@ -24,6 +24,7 @@ def test_auth_type_combo_has_options(self, qapp: QApplication, qtbot) -> None: items = [ editor._auth_type_combo.itemText(i) for i in range(editor._auth_type_combo.count()) ] + assert "Inherit auth from parent" in items assert "No Auth" in items assert "Bearer Token" in items assert "Basic Auth" in items @@ -86,12 +87,22 @@ def test_get_auth_data_bearer(self, qapp: QApplication, qtbot) -> None: assert data["auth"]["type"] == "bearer" assert data["auth"]["bearer"][0]["value"] == "abc" + def test_get_auth_data_inherit(self, qapp: QApplication, qtbot) -> None: + """get_request_data returns None when Inherit auth is selected.""" + editor = RequestEditorWidget() + qtbot.addWidget(editor) + + editor.load_request({"name": "X", "method": "GET", "url": "http://x"}) + data = editor.get_request_data() + assert data["auth"] is None + def test_get_auth_data_no_auth(self, qapp: QApplication, qtbot) -> None: """get_request_data returns noauth when No Auth is selected.""" editor = RequestEditorWidget() qtbot.addWidget(editor) editor.load_request({"name": "X", "method": "GET", "url": "http://x"}) + editor._auth_type_combo.setCurrentText("No Auth") data = editor.get_request_data() assert data["auth"]["type"] == "noauth" @@ -112,9 +123,32 @@ def test_clear_resets_auth(self, qapp: QApplication, qtbot) -> None: } ) editor.clear_request() - assert editor._auth_type_combo.currentText() == "No Auth" + assert editor._auth_type_combo.currentText() == "Inherit auth from parent" assert editor._bearer_token_input.text() == "" + def test_load_inherit_auth(self, qapp: QApplication, qtbot) -> None: + """Loading with auth=None selects Inherit auth from parent.""" + editor = RequestEditorWidget() + qtbot.addWidget(editor) + + editor.load_request({"name": "X", "method": "GET", "url": "http://x", "auth": None}) + assert editor._auth_type_combo.currentText() == "Inherit auth from parent" + + def test_load_noauth_selects_no_auth(self, qapp: QApplication, qtbot) -> None: + """Loading with explicit noauth selects No Auth.""" + editor = RequestEditorWidget() + qtbot.addWidget(editor) + + editor.load_request( + { + "name": "X", + "method": "GET", + "url": "http://x", + "auth": {"type": "noauth"}, + } + ) + assert editor._auth_type_combo.currentText() == "No Auth" + class TestApplyAuth: """Tests for HttpSendWorker._apply_auth static method.""" @@ -187,3 +221,137 @@ def test_noauth_no_modification(self) -> None: url, headers = HttpSendWorker._apply_auth({"type": "noauth"}, "http://x", "H: v", {}) assert url == "http://x" assert headers == "H: v" + + +class TestOAuth2PageWidget: + """Tests for the OAuth2Page custom widget.""" + + def test_load_and_get_entries_roundtrip(self, qapp: QApplication, qtbot) -> None: + """Loading entries and reading them back preserves values.""" + from ui.request.auth.oauth2_page import OAuth2Page + + page = OAuth2Page(on_change=lambda: None) + qtbot.addWidget(page) + + entries = [ + {"key": "accessToken", "value": "tok123"}, + {"key": "headerPrefix", "value": "Bearer"}, + {"key": "tokenName", "value": "My Token"}, + {"key": "addTokenTo", "value": "header"}, + {"key": "grant_type", "value": "client_credentials"}, + {"key": "accessTokenUrl", "value": "https://auth.test/token"}, + {"key": "clientId", "value": "cid"}, + {"key": "clientSecret", "value": "csec"}, + {"key": "scope", "value": "api"}, + {"key": "client_authentication", "value": "header"}, + ] + page.load(entries) + + result = page.get_entries() + result_map = {e["key"]: e["value"] for e in result} + + assert result_map["accessToken"] == "tok123" + assert result_map["headerPrefix"] == "Bearer" + assert result_map["tokenName"] == "My Token" + assert result_map["addTokenTo"] == "header" + assert result_map["grant_type"] == "client_credentials" + assert result_map["clientId"] == "cid" + + def test_grant_type_switching(self, qapp: QApplication, qtbot) -> None: + """Changing grant type shows correct field containers.""" + from ui.request.auth.oauth2_page import OAuth2Page + + page = OAuth2Page(on_change=lambda: None) + qtbot.addWidget(page) + + page._grant_type.setCurrentText("Password Credentials") + assert not page._grant_containers["Password Credentials"].isHidden() + assert page._grant_containers["Authorization Code"].isHidden() + + page._grant_type.setCurrentText("Implicit") + assert not page._grant_containers["Implicit"].isHidden() + assert page._grant_containers["Password Credentials"].isHidden() + + def test_get_config_returns_grant_type(self, qapp: QApplication, qtbot) -> None: + """get_config returns the correct grant_type key.""" + from ui.request.auth.oauth2_page import OAuth2Page + + page = OAuth2Page(on_change=lambda: None) + qtbot.addWidget(page) + + page._grant_type.setCurrentText("Client Credentials") + config = page.get_config() + assert config["grant_type"] == "client_credentials" + + def test_set_token(self, qapp: QApplication, qtbot) -> None: + """set_token populates the access token and display name.""" + from ui.request.auth.oauth2_page import OAuth2Page + + page = OAuth2Page(on_change=lambda: None) + qtbot.addWidget(page) + + page.set_token("new_token_value", "Test Token") + assert page._access_token.text() == "new_token_value" + assert page._token_name_display.text() == "Test Token" + + def test_clear_resets_all(self, qapp: QApplication, qtbot) -> None: + """clear() resets all fields to defaults.""" + from ui.request.auth.oauth2_page import OAuth2Page + + page = OAuth2Page(on_change=lambda: None) + qtbot.addWidget(page) + + page._access_token.setText("tok") + page._token_name.setText("name") + page.clear() + + assert page._access_token.text() == "" + assert page._token_name.text() == "" + assert page._header_prefix.text() == "Bearer" + + def test_load_oauth2_in_editor(self, qapp: QApplication, qtbot) -> None: + """Loading OAuth 2.0 auth in RequestEditor selects correct type.""" + editor = RequestEditorWidget() + qtbot.addWidget(editor) + + editor.load_request( + { + "name": "X", + "method": "GET", + "url": "http://x", + "auth": { + "type": "oauth2", + "oauth2": [ + {"key": "accessToken", "value": "mytoken"}, + {"key": "headerPrefix", "value": "Bearer"}, + {"key": "grant_type", "value": "client_credentials"}, + ], + }, + } + ) + assert editor._auth_type_combo.currentText() == "OAuth 2.0" + + def test_get_auth_data_oauth2(self, qapp: QApplication, qtbot) -> None: + """get_request_data returns OAuth 2.0 entries.""" + editor = RequestEditorWidget() + qtbot.addWidget(editor) + + editor.load_request( + { + "name": "X", + "method": "GET", + "url": "http://x", + "auth": { + "type": "oauth2", + "oauth2": [ + {"key": "accessToken", "value": "tok"}, + {"key": "grant_type", "value": "authorization_code"}, + ], + }, + } + ) + data = editor.get_request_data() + assert data["auth"]["type"] == "oauth2" + entries = data["auth"]["oauth2"] + entry_map = {e["key"]: e["value"] for e in entries} + assert entry_map["accessToken"] == "tok" diff --git a/tests/ui/sidebar/__init__.py b/tests/ui/sidebar/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/ui/sidebar/test_sidebar.py b/tests/ui/sidebar/test_sidebar.py new file mode 100644 index 0000000..a515814 --- /dev/null +++ b/tests/ui/sidebar/test_sidebar.py @@ -0,0 +1,180 @@ +"""Tests for the RightSidebar widget (icon rail + flyout panel).""" + +from __future__ import annotations + +from typing import Any + +from PySide6.QtWidgets import QApplication + +from ui.sidebar.sidebar_widget import RightSidebar +from ui.sidebar.snippet_panel import SnippetPanel +from ui.sidebar.variables_panel import VariablesPanel + + +class TestRightSidebar: + """Tests for the Postman-style icon-rail + flyout panel sidebar.""" + + def test_construction(self, qapp: QApplication, qtbot) -> None: + """RightSidebar can be instantiated without errors.""" + sidebar = RightSidebar() + qtbot.addWidget(sidebar) + assert sidebar.objectName() == "sidebarRail" + + def test_panels_exist(self, qapp: QApplication, qtbot) -> None: + """Sidebar exposes variables_panel and snippet_panel.""" + sidebar = RightSidebar() + qtbot.addWidget(sidebar) + assert isinstance(sidebar.variables_panel, VariablesPanel) + assert isinstance(sidebar.snippet_panel, SnippetPanel) + + def test_rail_buttons_exist(self, qapp: QApplication, qtbot) -> None: + """Sidebar has rail buttons for variables and code snippet.""" + sidebar = RightSidebar() + qtbot.addWidget(sidebar) + assert sidebar._var_btn is not None + assert sidebar._snippet_btn is not None + + def test_buttons_start_disabled(self, qapp: QApplication, qtbot) -> None: + """Rail buttons are disabled until a tab context is set.""" + sidebar = RightSidebar() + qtbot.addWidget(sidebar) + assert not sidebar._var_btn.isEnabled() + assert sidebar._snippet_btn.isHidden() + + def test_panel_starts_closed(self, qapp: QApplication, qtbot) -> None: + """No panel is open on construction.""" + sidebar = RightSidebar() + qtbot.addWidget(sidebar) + assert sidebar.active_panel is None + assert not sidebar.panel_open + + def test_open_panel_variables(self, qapp: QApplication, qtbot) -> None: + """open_panel('variables') opens the variables panel.""" + sidebar = RightSidebar() + qtbot.addWidget(sidebar) + sidebar.show_request_panels({}, method="GET", url="") + sidebar.open_panel("variables") + assert sidebar.active_panel == "variables" + assert sidebar.panel_open + assert sidebar._var_btn.isChecked() + assert not sidebar._snippet_btn.isChecked() + + def test_open_panel_snippet(self, qapp: QApplication, qtbot) -> None: + """open_panel('snippet') opens the snippet panel.""" + sidebar = RightSidebar() + qtbot.addWidget(sidebar) + sidebar.show_request_panels({}, method="GET", url="") + sidebar.open_panel("snippet") + assert sidebar.active_panel == "snippet" + assert sidebar.panel_open + assert not sidebar._var_btn.isChecked() + assert sidebar._snippet_btn.isChecked() + + def test_toggle_panel_closes_active(self, qapp: QApplication, qtbot) -> None: + """Clicking the active panel's icon closes the panel.""" + sidebar = RightSidebar() + qtbot.addWidget(sidebar) + sidebar.show_request_panels({}, method="GET", url="") + sidebar.open_panel("variables") + assert sidebar.active_panel == "variables" + + sidebar._toggle_panel("variables") + assert sidebar.active_panel is None + assert not sidebar.panel_open + + def test_toggle_panel_switches(self, qapp: QApplication, qtbot) -> None: + """Clicking a different icon switches the active panel.""" + sidebar = RightSidebar() + qtbot.addWidget(sidebar) + sidebar.show_request_panels({}, method="GET", url="") + sidebar.open_panel("variables") + + sidebar._toggle_panel("snippet") + assert sidebar.active_panel == "snippet" + assert sidebar._snippet_btn.isChecked() + assert not sidebar._var_btn.isChecked() + + def test_show_request_panels_enables_both( + self, + qapp: QApplication, + qtbot, + ) -> None: + """show_request_panels enables both rail icons.""" + sidebar = RightSidebar() + qtbot.addWidget(sidebar) + variables: dict[str, Any] = { + "key1": {"value": "val1", "source": "environment", "source_id": 1}, + } + sidebar.show_request_panels( + variables, + method="GET", + url="https://example.com", + ) + assert sidebar._var_btn.isEnabled() + assert not sidebar._snippet_btn.isHidden() + assert sidebar._snippet_btn.isEnabled() + + def test_show_folder_panels_disables_snippet( + self, + qapp: QApplication, + qtbot, + ) -> None: + """show_folder_panels enables variables but disables snippet.""" + sidebar = RightSidebar() + qtbot.addWidget(sidebar) + variables: dict[str, Any] = { + "key1": {"value": "val1", "source": "collection", "source_id": 5}, + } + sidebar.show_folder_panels(variables) + assert sidebar._var_btn.isEnabled() + assert sidebar._snippet_btn.isHidden() + + def test_show_folder_closes_snippet_panel( + self, + qapp: QApplication, + qtbot, + ) -> None: + """Switching to a folder closes the snippet panel if it was open.""" + sidebar = RightSidebar() + qtbot.addWidget(sidebar) + sidebar.show_request_panels({}, method="GET", url="") + sidebar.open_panel("snippet") + assert sidebar.active_panel == "snippet" + + sidebar.show_folder_panels({}) + assert sidebar.active_panel is None + + def test_clear_disables_all(self, qapp: QApplication, qtbot) -> None: + """clear() disables all icons and closes the panel.""" + sidebar = RightSidebar() + qtbot.addWidget(sidebar) + sidebar.show_request_panels({}, method="GET", url="") + sidebar.open_panel("variables") + sidebar.clear() + assert not sidebar._var_btn.isEnabled() + assert sidebar._snippet_btn.isHidden() + assert sidebar.active_panel is None + + def test_close_button_closes_panel(self, qapp: QApplication, qtbot) -> None: + """Clicking the close button hides the active panel.""" + sidebar = RightSidebar() + qtbot.addWidget(sidebar) + sidebar.show_request_panels({}, method="GET", url="") + sidebar.open_panel("variables") + assert sidebar.panel_open + + sidebar._close_btn.click() + assert not sidebar.panel_open + assert sidebar.active_panel is None + + def test_open_unavailable_panel_ignored( + self, + qapp: QApplication, + qtbot, + ) -> None: + """Opening a panel not in available_panels is a no-op.""" + sidebar = RightSidebar() + qtbot.addWidget(sidebar) + sidebar.show_folder_panels({}) + sidebar.open_panel("snippet") + assert sidebar.active_panel is None diff --git a/tests/ui/sidebar/test_snippet_panel.py b/tests/ui/sidebar/test_snippet_panel.py new file mode 100644 index 0000000..177acfd --- /dev/null +++ b/tests/ui/sidebar/test_snippet_panel.py @@ -0,0 +1,328 @@ +"""Tests for the SnippetPanel widget.""" + +from __future__ import annotations + +from PySide6.QtCore import QSettings +from PySide6.QtWidgets import QApplication + +from services.http.snippet_generator import SnippetGenerator +from ui.sidebar.snippet_panel import SnippetPanel, SnippetSettingsPopup + + +class TestSnippetPanel: + """Tests for the inline code snippet panel.""" + + def test_construction(self, qapp: QApplication, qtbot) -> None: + """SnippetPanel can be instantiated without errors.""" + panel = SnippetPanel() + qtbot.addWidget(panel) + assert panel is not None + + def test_default_language(self, qapp: QApplication, qtbot) -> None: + """The language combo starts with the first available language.""" + panel = SnippetPanel() + qtbot.addWidget(panel) + assert panel._lang_combo.currentText() == "cURL" + + def test_update_request_generates_snippet(self, qapp: QApplication, qtbot) -> None: + """update_request populates the code editor with a snippet.""" + panel = SnippetPanel() + qtbot.addWidget(panel) + panel.update_request(method="GET", url="https://api.example.com/users") + text = panel._code_edit.toPlainText() + assert "curl" in text.lower() + assert "https://api.example.com/users" in text + + def test_language_switch_regenerates(self, qapp: QApplication, qtbot) -> None: + """Switching the language combo regenerates the snippet.""" + panel = SnippetPanel() + qtbot.addWidget(panel) + panel.update_request(method="POST", url="https://api.example.com/data") + panel._lang_combo.setCurrentText("Python (requests)") + text = panel._code_edit.toPlainText() + assert "requests" in text.lower() + + def test_copy_to_clipboard(self, qapp: QApplication, qtbot) -> None: + """Clicking copy sets the text to the system clipboard.""" + panel = SnippetPanel() + qtbot.addWidget(panel) + panel.update_request(method="GET", url="https://example.com") + panel._copy_btn.click() + assert panel._status_label.text() == "Copied!" + + def test_clear_resets_state(self, qapp: QApplication, qtbot) -> None: + """clear() empties the code editor.""" + panel = SnippetPanel() + qtbot.addWidget(panel) + panel.update_request(method="GET", url="https://example.com") + panel.clear() + assert panel._code_edit.toPlainText() == "" + + def test_empty_url_no_snippet(self, qapp: QApplication, qtbot) -> None: + """When URL is empty, no snippet is generated.""" + panel = SnippetPanel() + qtbot.addWidget(panel) + panel.update_request(method="GET", url="") + assert panel._code_edit.toPlainText() == "" + + def test_snippet_with_headers_and_body(self, qapp: QApplication, qtbot) -> None: + """Snippet includes headers and body when provided.""" + panel = SnippetPanel() + qtbot.addWidget(panel) + panel.update_request( + method="POST", + url="https://api.example.com/data", + headers="Content-Type: application/json", + body='{"key": "value"}', + ) + text = panel._code_edit.toPlainText() + assert "Content-Type" in text + assert "key" in text + + def test_snippet_with_bearer_auth(self, qapp: QApplication, qtbot) -> None: + """Snippet includes Authorization header when bearer auth is set.""" + panel = SnippetPanel() + qtbot.addWidget(panel) + auth = {"type": "bearer", "bearer": [{"key": "token", "value": "tok123"}]} + panel.update_request( + method="GET", + url="https://api.example.com", + auth=auth, + ) + text = panel._code_edit.toPlainText() + assert "Authorization" in text + assert "Bearer tok123" in text + + def test_syntax_highlighting_language(self, qapp: QApplication, qtbot) -> None: + """Snippet editor uses correct syntax language per combo selection.""" + panel = SnippetPanel() + qtbot.addWidget(panel) + panel._lang_combo.setCurrentText("cURL") + panel.update_request(method="GET", url="https://example.com") + assert panel._code_edit._language == "bash" + + panel._lang_combo.setCurrentText("Python (requests)") + assert panel._code_edit._language == "python" + + panel._lang_combo.setCurrentText("JavaScript (fetch)") + assert panel._code_edit._language == "javascript" + + def test_all_languages_in_combo(self, qapp: QApplication, qtbot) -> None: + """Combo box contains all 23 registered languages.""" + panel = SnippetPanel() + qtbot.addWidget(panel) + count = panel._lang_combo.count() + assert count == len(SnippetGenerator.available_languages()) + assert count == 23 + + def test_gear_button_exists(self, qapp: QApplication, qtbot) -> None: + """Settings gear button is present in the panel.""" + panel = SnippetPanel() + qtbot.addWidget(panel) + assert panel._settings_btn is not None + assert panel._settings_btn.toolTip() == "Snippet settings" + + def test_toggle_settings_opens_popup(self, qapp: QApplication, qtbot) -> None: + """Clicking gear button creates and shows the settings popup.""" + panel = SnippetPanel() + qtbot.addWidget(panel) + assert panel._settings_popup is None + panel._toggle_settings() + assert panel._settings_popup is not None + assert panel._settings_popup.isVisible() + + def test_toggle_settings_closes_popup(self, qapp: QApplication, qtbot) -> None: + """Second click on gear button hides the popup.""" + panel = SnippetPanel() + qtbot.addWidget(panel) + panel._toggle_settings() + assert panel._settings_popup is not None + assert panel._settings_popup.isVisible() + panel._toggle_settings() + assert not panel._settings_popup.isVisible() + + def test_options_propagated_to_generate(self, qapp: QApplication, qtbot) -> None: + """Options from the settings popup affect snippet generation.""" + panel = SnippetPanel() + qtbot.addWidget(panel) + panel._toggle_settings() + assert panel._settings_popup is not None + panel._settings_popup._indent_count.setValue(4) + panel.update_request(method="GET", url="https://api.example.com") + # Re-open should preserve indent count + assert panel._settings_popup._indent_count.value() == 4 + + def test_language_persisted_to_qsettings(self, qapp: QApplication, qtbot) -> None: + """Switching language persists it via QSettings.""" + panel = SnippetPanel() + qtbot.addWidget(panel) + panel._lang_combo.setCurrentText("Go (net/http)") + saved = QSettings().value("snippet/last_language") + assert saved == "Go (net/http)" + + def test_lexer_from_registry(self, qapp: QApplication, qtbot) -> None: + """Each language combo entry sets the correct lexer via registry.""" + panel = SnippetPanel() + qtbot.addWidget(panel) + panel.update_request(method="GET", url="https://example.com") + # Python (requests) -> python lexer + panel._lang_combo.setCurrentText("Python (requests)") + assert panel._code_edit._language == "python" + # Go (net/http) -> go lexer + panel._lang_combo.setCurrentText("Go (net/http)") + assert panel._code_edit._language == "go" + # Rust (reqwest) -> rust lexer + panel._lang_combo.setCurrentText("Rust (reqwest)") + assert panel._code_edit._language == "rust" + + +class TestSnippetSettingsPopup: + """Tests for the SnippetSettingsPopup widget.""" + + def test_construction(self, qapp: QApplication, qtbot) -> None: + """Popup can be instantiated.""" + popup = SnippetSettingsPopup() + qtbot.addWidget(popup) + assert popup is not None + + def test_get_options_defaults(self, qapp: QApplication, qtbot) -> None: + """Default options match expected defaults.""" + QSettings().remove("snippet") + popup = SnippetSettingsPopup() + qtbot.addWidget(popup) + opts = popup.get_options() + assert opts["indent_count"] == 2 + assert opts["indent_type"] == "space" + assert opts["trim_body"] is False + assert opts["follow_redirect"] is True + assert opts["request_timeout"] == 0 + assert opts["include_boilerplate"] is True + assert opts["async_await"] is False + assert opts["es6_features"] is False + + def test_set_language_options_hides_controls(self, qapp: QApplication, qtbot) -> None: + """Controls not in applicable_options are hidden.""" + popup = SnippetSettingsPopup() + qtbot.addWidget(popup) + popup.set_language_options(("trim_body",)) + assert popup._indent_count.isHidden() + assert popup._indent_type.isHidden() + assert popup._follow_redirect.isHidden() + assert popup._request_timeout.isHidden() + assert popup._include_boilerplate.isHidden() + assert popup._async_await.isHidden() + assert popup._es6_features.isHidden() + assert not popup._trim_body.isHidden() + + def test_set_language_options_shows_controls(self, qapp: QApplication, qtbot) -> None: + """Controls in applicable_options are not hidden.""" + popup = SnippetSettingsPopup() + qtbot.addWidget(popup) + popup.set_language_options( + ( + "indent_count", + "indent_type", + "trim_body", + "follow_redirect", + "request_timeout", + "include_boilerplate", + "async_await", + "es6_features", + ) + ) + assert not popup._indent_count.isHidden() + assert not popup._indent_type.isHidden() + assert not popup._trim_body.isHidden() + assert not popup._follow_redirect.isHidden() + assert not popup._request_timeout.isHidden() + assert not popup._include_boilerplate.isHidden() + assert not popup._async_await.isHidden() + assert not popup._es6_features.isHidden() + + def test_set_language_options_httpie_hides_indent(self, qapp: QApplication, qtbot) -> None: + """HTTPie options hide indent controls.""" + popup = SnippetSettingsPopup() + qtbot.addWidget(popup) + popup.set_language_options(("request_timeout", "follow_redirect")) + assert popup._indent_count.isHidden() + assert popup._indent_type.isHidden() + assert popup._trim_body.isHidden() + assert not popup._request_timeout.isHidden() + assert not popup._follow_redirect.isHidden() + + def test_on_settings_changed_callback(self, qapp: QApplication, qtbot) -> None: + """Changing a value fires the on_settings_changed callback.""" + popup = SnippetSettingsPopup() + qtbot.addWidget(popup) + called = [] + popup.on_settings_changed(lambda: called.append(True)) + popup._indent_count.setValue(4) + assert len(called) == 1 + + def test_saves_to_qsettings(self, qapp: QApplication, qtbot) -> None: + """Changing a setting persists the value via QSettings.""" + popup = SnippetSettingsPopup() + qtbot.addWidget(popup) + popup._indent_count.setValue(6) + saved = int(str(QSettings().value("snippet/indent_count", 2))) + assert saved == 6 + + def test_get_options_includes_new_curl_fields(self, qapp: QApplication, qtbot) -> None: + """get_options returns all 14 fields including 6 new cURL options.""" + QSettings().remove("snippet") + popup = SnippetSettingsPopup() + qtbot.addWidget(popup) + opts = popup.get_options() + assert opts["multiline"] is True + assert opts["long_form"] is True + assert opts["line_continuation"] == "\\" + assert opts["quote_type"] == "single" + assert opts["follow_original_method"] is False + assert opts["silent_mode"] is False + + def test_curl_options_visible_for_curl(self, qapp: QApplication, qtbot) -> None: + """cURL-specific controls are visible when cURL options are set.""" + popup = SnippetSettingsPopup() + qtbot.addWidget(popup) + popup.set_language_options( + ( + "trim_body", + "request_timeout", + "follow_redirect", + "follow_original_method", + "multiline", + "long_form", + "line_continuation", + "quote_type", + "silent_mode", + ) + ) + assert not popup._multiline.isHidden() + assert not popup._long_form.isHidden() + assert not popup._line_continuation.isHidden() + assert not popup._quote_type.isHidden() + assert not popup._follow_original_method.isHidden() + assert not popup._silent_mode.isHidden() + + def test_curl_options_hidden_for_python(self, qapp: QApplication, qtbot) -> None: + """cURL-specific controls are hidden for non-cURL languages.""" + popup = SnippetSettingsPopup() + qtbot.addWidget(popup) + popup.set_language_options( + ("indent_count", "indent_type", "trim_body", "request_timeout", "follow_redirect") + ) + assert popup._multiline.isHidden() + assert popup._long_form.isHidden() + assert popup._line_continuation.isHidden() + assert popup._quote_type.isHidden() + assert popup._follow_original_method.isHidden() + assert popup._silent_mode.isHidden() + + def test_new_curl_options_change_fires_callback(self, qapp: QApplication, qtbot) -> None: + """Changing a cURL option fires on_settings_changed callback.""" + popup = SnippetSettingsPopup() + qtbot.addWidget(popup) + called = [] + popup.on_settings_changed(lambda: called.append(True)) + popup._multiline.setChecked(False) + assert len(called) >= 1 diff --git a/tests/ui/sidebar/test_variables_panel.py b/tests/ui/sidebar/test_variables_panel.py new file mode 100644 index 0000000..6dfd426 --- /dev/null +++ b/tests/ui/sidebar/test_variables_panel.py @@ -0,0 +1,132 @@ +"""Tests for the VariablesPanel widget.""" + +from __future__ import annotations + +from typing import Any + +from PySide6.QtWidgets import QApplication, QLabel + +from ui.sidebar.variables_panel import VariablesPanel + + +class TestVariablesPanel: + """Tests for the read-only variables display panel.""" + + def test_construction(self, qapp: QApplication, qtbot) -> None: + """VariablesPanel can be instantiated with empty state.""" + panel = VariablesPanel() + qtbot.addWidget(panel) + assert panel is not None + + def test_empty_state_label(self, qapp: QApplication, qtbot) -> None: + """Panel shows 'No variables available' when empty.""" + panel = VariablesPanel() + qtbot.addWidget(panel) + labels = panel._content.findChildren(QLabel) + texts = [lbl.text() for lbl in labels] + assert "No variables available" in texts + + def test_load_environment_variables(self, qapp: QApplication, qtbot) -> None: + """Panel renders environment variables under the Environment section.""" + panel = VariablesPanel() + qtbot.addWidget(panel) + variables: dict[str, Any] = { + "api_key": {"value": "abc123", "source": "environment", "source_id": 1}, + } + panel.load_variables(variables, has_environment=True) + labels = panel._content.findChildren(QLabel) + texts = [lbl.text() for lbl in labels] + assert "api_key" in texts + assert "abc123" in texts + # Section header should be present + assert "Environment" in texts + + def test_load_collection_variables(self, qapp: QApplication, qtbot) -> None: + """Panel renders collection variables under the Collection section.""" + panel = VariablesPanel() + qtbot.addWidget(panel) + variables: dict[str, Any] = { + "base_url": { + "value": "https://api.example.com", + "source": "collection", + "source_id": 5, + }, + } + panel.load_variables(variables, has_environment=True) + labels = panel._content.findChildren(QLabel) + texts = [lbl.text() for lbl in labels] + assert "base_url" in texts + assert "Requests collection" in texts + + def test_no_environment_message(self, qapp: QApplication, qtbot) -> None: + """Panel shows 'No environment selected' when has_environment=False.""" + panel = VariablesPanel() + qtbot.addWidget(panel) + panel.load_variables({}, has_environment=False) + labels = panel._content.findChildren(QLabel) + texts = [lbl.text() for lbl in labels] + assert any("No environment selected" in t for t in texts) + + def test_local_overrides_section(self, qapp: QApplication, qtbot) -> None: + """Panel shows local overrides in a separate section.""" + panel = VariablesPanel() + qtbot.addWidget(panel) + variables: dict[str, Any] = { + "token": {"value": "original", "source": "environment", "source_id": 1}, + } + overrides: dict[str, Any] = { + "token": { + "value": "local_val", + "original_source": "environment", + "original_source_id": 1, + }, + } + panel.load_variables(variables, local_overrides=overrides, has_environment=True) + labels = panel._content.findChildren(QLabel) + texts = [lbl.text() for lbl in labels] + assert "Local overrides" in texts + assert "local_val" in texts + + def test_clear_resets_to_empty(self, qapp: QApplication, qtbot) -> None: + """clear() returns to the empty state.""" + panel = VariablesPanel() + qtbot.addWidget(panel) + variables: dict[str, Any] = { + "x": {"value": "y", "source": "environment", "source_id": 1}, + } + panel.load_variables(variables, has_environment=True) + panel.clear() + labels = panel._content.findChildren(QLabel) + texts = [lbl.text() for lbl in labels] + assert "No variables available" in texts + + def test_grouping_multiple_sources(self, qapp: QApplication, qtbot) -> None: + """Panel groups variables by source correctly.""" + panel = VariablesPanel() + qtbot.addWidget(panel) + variables: dict[str, Any] = { + "env_var": {"value": "e_val", "source": "environment", "source_id": 1}, + "coll_var": {"value": "c_val", "source": "collection", "source_id": 5}, + } + panel.load_variables(variables, has_environment=True) + labels = panel._content.findChildren(QLabel) + texts = [lbl.text() for lbl in labels] + assert "Environment" in texts + assert "Requests collection" in texts + assert "env_var" in texts + assert "coll_var" in texts + + def test_long_value_has_tooltip(self, qapp: QApplication, qtbot) -> None: + """Long values keep the full text available as a tooltip.""" + panel = VariablesPanel() + qtbot.addWidget(panel) + long_val = "a" * 60 + variables: dict[str, Any] = { + "long_key": {"value": long_val, "source": "environment", "source_id": 1}, + } + panel.load_variables(variables, has_environment=True) + labels = panel._content.findChildren(QLabel) + value_labels = [lbl for lbl in labels if lbl.objectName() == "variableValueLabel"] + assert len(value_labels) == 1 + assert value_labels[0].text() == long_val + assert value_labels[0].toolTip() == long_val diff --git a/tests/ui/test_main_window.py b/tests/ui/test_main_window.py index 5ff2745..c5173e6 100644 --- a/tests/ui/test_main_window.py +++ b/tests/ui/test_main_window.py @@ -457,6 +457,82 @@ def test_close_all(self, qapp: QApplication, qtbot) -> None: assert window._tab_bar.count() == 0 +class TestMainWindowRightSidebar: + """Tests for the right sidebar icon rail and flyout panels.""" + + def test_toggle_right_sidebar(self, qapp: QApplication, qtbot) -> None: + """Toggling the right sidebar opens and closes the panel.""" + window = MainWindow() + qtbot.addWidget(window) + # Sidebar starts with no panel open + assert not window._right_sidebar.panel_open + + # Need a tab context to enable panels + svc = CollectionService() + coll = svc.create_collection("C") + req = svc.create_request(coll.id, "GET", "http://x", "R") + window._open_request(req.id, push_history=True) + + # No auto-open; sidebar should still be closed. + assert not window._right_sidebar.panel_open + window._toggle_right_sidebar() + assert window._right_sidebar.panel_open + window._toggle_right_sidebar() + assert not window._right_sidebar.panel_open + + def test_sidebar_rail_always_visible(self, qapp: QApplication, qtbot) -> None: + """The icon rail is not hidden — it is always present in the layout.""" + window = MainWindow() + qtbot.addWidget(window) + # The sidebar itself is not explicitly hidden + assert not window._right_sidebar.isHidden() + # The rail inside the sidebar is not hidden either + assert not window._right_sidebar._rail.isHidden() + + def test_sidebar_shows_request_panels_on_tab_switch(self, qapp: QApplication, qtbot) -> None: + """Switching to a request tab enables both sidebar icons.""" + svc = CollectionService() + coll = svc.create_collection("C") + req = svc.create_request(coll.id, "GET", "http://example.com", "R") + + window = MainWindow() + qtbot.addWidget(window) + + window._open_request(req.id, push_history=True) + + assert window._right_sidebar._var_btn.isEnabled() + assert not window._right_sidebar._snippet_btn.isHidden() + + def test_sidebar_shows_folder_panels_on_tab_switch(self, qapp: QApplication, qtbot) -> None: + """Switching to a folder tab enables variables but disables snippet.""" + svc = CollectionService() + coll = svc.create_collection("Folder") + + window = MainWindow() + qtbot.addWidget(window) + + window._open_folder(coll.id) + + assert window._right_sidebar._var_btn.isEnabled() + assert window._right_sidebar._snippet_btn.isHidden() + + def test_sidebar_clears_when_no_tab(self, qapp: QApplication, qtbot) -> None: + """Closing all tabs disables sidebar icons.""" + svc = CollectionService() + coll = svc.create_collection("C") + req = svc.create_request(coll.id, "GET", "http://x", "R") + + window = MainWindow() + qtbot.addWidget(window) + + window._open_request(req.id, push_history=True) + assert window._right_sidebar._var_btn.isEnabled() + + window._on_tab_close(0) + assert not window._right_sidebar._var_btn.isEnabled() + assert window._right_sidebar._snippet_btn.isHidden() + + class TestMainWindowFolderTabs: """Tests for opening and closing folder tabs.""" diff --git a/tests/ui/test_main_window_draft.py b/tests/ui/test_main_window_draft.py new file mode 100644 index 0000000..88b3295 --- /dev/null +++ b/tests/ui/test_main_window_draft.py @@ -0,0 +1,209 @@ +"""Tests for draft request tab lifecycle (open, save, upgrade).""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from PySide6.QtWidgets import QApplication + +from services.collection_service import CollectionService +from ui.main_window import MainWindow + + +class TestOpenDraftRequest: + """Tests for ``_open_draft_request`` — creating an unsaved tab.""" + + def test_draft_tab_created(self, qapp: QApplication, qtbot) -> None: + """Opening a draft creates a new tab in the tab bar.""" + window = MainWindow() + qtbot.addWidget(window) + + assert window._tab_bar.count() == 0 + window._open_draft_request() + assert window._tab_bar.count() == 1 + + def test_draft_tab_name(self, qapp: QApplication, qtbot) -> None: + """Draft tab is labelled 'Untitled Request'.""" + window = MainWindow() + qtbot.addWidget(window) + + window._open_draft_request() + assert window._tab_bar.tabToolTip(0) == "Untitled Request" + + def test_draft_tab_has_no_request_id(self, qapp: QApplication, qtbot) -> None: + """Draft tab context has ``request_id=None``.""" + window = MainWindow() + qtbot.addWidget(window) + + window._open_draft_request() + ctx = window._tabs[0] + assert ctx.request_id is None + + def test_draft_editor_is_dirty(self, qapp: QApplication, qtbot) -> None: + """Draft editor starts dirty so the Save button is enabled.""" + window = MainWindow() + qtbot.addWidget(window) + + window._open_draft_request() + ctx = window._tabs[0] + assert ctx.editor is not None + assert ctx.editor.is_dirty + + def test_draft_save_btn_enabled(self, qapp: QApplication, qtbot) -> None: + """Save button is enabled after opening a draft tab.""" + window = MainWindow() + qtbot.addWidget(window) + + window._open_draft_request() + assert window._save_btn.isEnabled() + + def test_multiple_drafts(self, qapp: QApplication, qtbot) -> None: + """Multiple draft tabs can be opened.""" + window = MainWindow() + qtbot.addWidget(window) + + window._open_draft_request() + window._open_draft_request() + assert window._tab_bar.count() == 2 + + def test_draft_editor_empty_url(self, qapp: QApplication, qtbot) -> None: + """Draft editor has an empty URL field.""" + window = MainWindow() + qtbot.addWidget(window) + + window._open_draft_request() + ctx = window._tabs[0] + assert ctx.editor is not None + assert ctx.editor._url_input.text() == "" + + def test_draft_editor_get_method(self, qapp: QApplication, qtbot) -> None: + """Draft editor defaults to GET method.""" + window = MainWindow() + qtbot.addWidget(window) + + window._open_draft_request() + ctx = window._tabs[0] + assert ctx.editor is not None + assert ctx.editor._method_combo.currentText() == "GET" + + def test_draft_breadcrumb_shows_untitled(self, qapp: QApplication, qtbot) -> None: + """Opening a draft shows 'Untitled Request' in the breadcrumb bar.""" + window = MainWindow() + qtbot.addWidget(window) + + window._open_draft_request() + seg = window._breadcrumb_bar.last_segment_info + assert seg is not None + assert seg["name"] == "Untitled Request" + + def test_draft_context_has_draft_name(self, qapp: QApplication, qtbot) -> None: + """Draft tab context stores the draft_name.""" + window = MainWindow() + qtbot.addWidget(window) + + window._open_draft_request() + ctx = window._tabs[0] + assert ctx.draft_name == "Untitled Request" + + def test_draft_breadcrumb_rename_updates_tab(self, qapp: QApplication, qtbot) -> None: + """Renaming via breadcrumb updates the tab label and context.""" + window = MainWindow() + qtbot.addWidget(window) + + window._open_draft_request() + ctx = window._tabs[0] + + # Simulate breadcrumb rename + window._on_breadcrumb_rename("My Custom Request") + + assert ctx.draft_name == "My Custom Request" + assert window._tab_bar.tabToolTip(0) == "My Custom Request" + + +class TestSaveDraftRequest: + """Tests for ``_save_draft_request`` — persisting a draft to a collection.""" + + def test_save_draft_creates_request_in_db(self, qapp: QApplication, qtbot) -> None: + """Saving a draft creates a new request in the database.""" + coll = CollectionService.create_collection("TestColl") + window = MainWindow() + qtbot.addWidget(window) + + window._open_draft_request() + ctx = window._tabs[0] + assert ctx.editor is not None + ctx.editor._url_input.setText("http://draft.test") + ctx.editor._method_combo.setCurrentText("POST") + + mock_dialog = MagicMock() + mock_dialog.exec.return_value = mock_dialog.DialogCode.Accepted + mock_dialog.request_name.return_value = "My Draft" + mock_dialog.selected_collection_id.return_value = coll.id + + with patch( + "ui.dialogs.save_request_dialog.SaveRequestDialog", + return_value=mock_dialog, + ): + window._on_save_request() + + # Tab should now have a real request_id + assert ctx.request_id is not None + + # Verify in DB + saved = CollectionService.get_request(ctx.request_id) + assert saved is not None + assert saved.url == "http://draft.test" + assert saved.method == "POST" + + def test_save_draft_clears_dirty(self, qapp: QApplication, qtbot) -> None: + """After saving a draft, the editor is no longer dirty.""" + coll = CollectionService.create_collection("TestColl") + window = MainWindow() + qtbot.addWidget(window) + + window._open_draft_request() + ctx = window._tabs[0] + assert ctx.editor is not None + assert ctx.editor.is_dirty + + mock_dialog = MagicMock() + mock_dialog.exec.return_value = mock_dialog.DialogCode.Accepted + mock_dialog.request_name.return_value = "Saved" + mock_dialog.selected_collection_id.return_value = coll.id + + with patch( + "ui.dialogs.save_request_dialog.SaveRequestDialog", + return_value=mock_dialog, + ): + window._on_save_request() + + assert not ctx.editor.is_dirty + + def test_save_draft_cancelled_keeps_draft(self, qapp: QApplication, qtbot) -> None: + """Cancelling the save dialog keeps the tab as a draft.""" + window = MainWindow() + qtbot.addWidget(window) + + window._open_draft_request() + ctx = window._tabs[0] + + mock_dialog = MagicMock() + mock_dialog.exec.return_value = mock_dialog.DialogCode.Rejected + + with patch( + "ui.dialogs.save_request_dialog.SaveRequestDialog", + return_value=mock_dialog, + ): + window._on_save_request() + + assert ctx.request_id is None + + def test_save_noop_without_tab(self, qapp: QApplication, qtbot) -> None: + """Save does nothing when no tab is open and editor has no request_id.""" + window = MainWindow() + qtbot.addWidget(window) + window.request_widget._url_input.setText("http://whatever") + + with patch.object(CollectionService, "update_request") as mock_update: + window._on_save_request() + mock_update.assert_not_called() diff --git a/tests/unit/services/http/test_auth_handler.py b/tests/unit/services/http/test_auth_handler.py new file mode 100644 index 0000000..386a33e --- /dev/null +++ b/tests/unit/services/http/test_auth_handler.py @@ -0,0 +1,611 @@ +"""Tests for the shared auth header injection handler.""" + +from __future__ import annotations + +import base64 +import json + +from services.http.auth_handler import apply_auth + + +class TestApplyAuthNoop: + """Verify no-op scenarios return inputs unchanged.""" + + def test_none_auth(self) -> None: + """None auth returns url and headers unchanged.""" + url, hdr = apply_auth(None, "https://x.io", {}) + assert url == "https://x.io" + assert hdr == {} + + def test_noauth_type(self) -> None: + """Explicit 'noauth' returns inputs unchanged.""" + url, hdr = apply_auth({"type": "noauth"}, "https://x.io", {"A": "1"}) + assert url == "https://x.io" + assert hdr == {"A": "1"} + + def test_unknown_type(self) -> None: + """Unknown auth type returns inputs unchanged.""" + url, hdr = apply_auth({"type": "custom123"}, "https://x.io", {}) + assert url == "https://x.io" + assert hdr == {} + + +class TestBearerAuth: + """Bearer token injection.""" + + def test_adds_header(self) -> None: + """Token is added as Authorization: Bearer header.""" + auth = {"type": "bearer", "bearer": [{"key": "token", "value": "abc"}]} + _, hdr = apply_auth(auth, "https://x.io", {}) + assert hdr["Authorization"] == "Bearer abc" + + def test_empty_token(self) -> None: + """Empty token does not add a header.""" + auth = {"type": "bearer", "bearer": [{"key": "token", "value": ""}]} + _, hdr = apply_auth(auth, "https://x.io", {}) + assert "Authorization" not in hdr + + +class TestBasicAuth: + """Basic auth injection.""" + + def test_adds_header(self) -> None: + """Encodes username:password in base64.""" + auth = { + "type": "basic", + "basic": [ + {"key": "username", "value": "user"}, + {"key": "password", "value": "pass"}, + ], + } + _, hdr = apply_auth(auth, "https://x.io", {}) + expected = base64.b64encode(b"user:pass").decode() + assert hdr["Authorization"] == f"Basic {expected}" + + def test_empty_creds(self) -> None: + """Empty username and password produces no header.""" + auth = { + "type": "basic", + "basic": [ + {"key": "username", "value": ""}, + {"key": "password", "value": ""}, + ], + } + _, hdr = apply_auth(auth, "https://x.io", {}) + assert "Authorization" not in hdr + + +class TestApiKeyAuth: + """API key header and query parameter injection.""" + + def test_header(self) -> None: + """API key is added as a custom header.""" + auth = { + "type": "apikey", + "apikey": [ + {"key": "key", "value": "X-Api-Key"}, + {"key": "value", "value": "secret"}, + {"key": "in", "value": "header"}, + ], + } + _, hdr = apply_auth(auth, "https://x.io", {}) + assert hdr["X-Api-Key"] == "secret" + + def test_query_param(self) -> None: + """API key is appended as a URL query parameter.""" + auth = { + "type": "apikey", + "apikey": [ + {"key": "key", "value": "api_key"}, + {"key": "value", "value": "secret"}, + {"key": "in", "value": "query"}, + ], + } + url, _ = apply_auth(auth, "https://x.io/path", {}) + assert "api_key=secret" in url + assert "?" in url + + def test_query_param_appends_with_ampersand(self) -> None: + """When URL already has query params, uses & separator.""" + auth = { + "type": "apikey", + "apikey": [ + {"key": "key", "value": "k"}, + {"key": "value", "value": "v"}, + {"key": "in", "value": "query"}, + ], + } + url, _ = apply_auth(auth, "https://x.io/path?a=1", {}) + assert "?a=1&k=v" in url + + +class TestOAuth2Auth: + """OAuth 2.0 manual token injection.""" + + def test_header_with_prefix(self) -> None: + """Token is added with the configured prefix.""" + auth = { + "type": "oauth2", + "oauth2": [ + {"key": "accessToken", "value": "tok"}, + {"key": "headerPrefix", "value": "Bearer"}, + {"key": "addTokenTo", "value": "header"}, + ], + } + _, hdr = apply_auth(auth, "https://x.io", {}) + assert hdr["Authorization"] == "Bearer tok" + + def test_query_param(self) -> None: + """Token is added as query param when configured.""" + auth = { + "type": "oauth2", + "oauth2": [ + {"key": "accessToken", "value": "tok"}, + {"key": "addTokenTo", "value": "queryParams"}, + ], + } + url, _ = apply_auth(auth, "https://x.io", {}) + assert "access_token=tok" in url + + +class TestDigestAuth: + """Digest auth header generation (RFC 7616).""" + + def test_md5_auth(self) -> None: + """Produces a valid Digest header with MD5 algorithm.""" + auth = { + "type": "digest", + "digest": [ + {"key": "username", "value": "user"}, + {"key": "password", "value": "pass"}, + {"key": "realm", "value": "testrealm"}, + {"key": "nonce", "value": "abc123"}, + {"key": "algorithm", "value": "MD5"}, + {"key": "qop", "value": "auth"}, + {"key": "nonceCount", "value": "00000001"}, + {"key": "clientNonce", "value": "deadbeef"}, + {"key": "opaque", "value": "opq"}, + ], + } + _, hdr = apply_auth(auth, "https://x.io/path", {}, method="GET") + val = hdr["Authorization"] + assert val.startswith("Digest ") + assert 'username="user"' in val + assert 'realm="testrealm"' in val + assert "algorithm=MD5" in val + assert "qop=auth" in val + assert 'opaque="opq"' in val + + def test_no_qop(self) -> None: + """Digest without qop omits nc and cnonce from header.""" + auth = { + "type": "digest", + "digest": [ + {"key": "username", "value": "u"}, + {"key": "password", "value": "p"}, + {"key": "realm", "value": "r"}, + {"key": "nonce", "value": "n"}, + {"key": "algorithm", "value": "MD5"}, + ], + } + _, hdr = apply_auth(auth, "https://x.io/", {}, method="GET") + val = hdr["Authorization"] + assert "qop=" not in val + assert "nc=" not in val + + +class TestOAuth1Auth: + """OAuth 1.0 signature generation (RFC 5849).""" + + def test_hmac_sha1_header(self) -> None: + """Produces Authorization: OAuth header with signature.""" + auth = { + "type": "oauth1", + "oauth1": [ + {"key": "consumerKey", "value": "ck"}, + {"key": "consumerSecret", "value": "cs"}, + {"key": "token", "value": "tok"}, + {"key": "tokenSecret", "value": "ts"}, + {"key": "signatureMethod", "value": "HMAC-SHA1"}, + {"key": "timestamp", "value": "1234567890"}, + {"key": "nonce", "value": "testnonce"}, + {"key": "version", "value": "1.0"}, + {"key": "addParamsToHeader", "value": True}, + ], + } + _, hdr = apply_auth(auth, "https://x.io/api", {}, method="GET") + val = hdr["Authorization"] + assert val.startswith("OAuth ") + assert "oauth_consumer_key" in val + assert "oauth_signature" in val + + def test_plaintext_signature(self) -> None: + """PLAINTEXT signature is consumer_secret&token_secret.""" + auth = { + "type": "oauth1", + "oauth1": [ + {"key": "consumerKey", "value": "ck"}, + {"key": "consumerSecret", "value": "cs"}, + {"key": "token", "value": ""}, + {"key": "tokenSecret", "value": "ts"}, + {"key": "signatureMethod", "value": "PLAINTEXT"}, + {"key": "timestamp", "value": "0"}, + {"key": "nonce", "value": "n"}, + {"key": "addParamsToHeader", "value": True}, + ], + } + _, hdr = apply_auth(auth, "https://x.io", {}, method="GET") + assert "cs%26ts" in hdr["Authorization"] or "cs&ts" in hdr["Authorization"] + + def test_query_params(self) -> None: + """OAuth params appended to URL when addParamsToHeader is false.""" + auth = { + "type": "oauth1", + "oauth1": [ + {"key": "consumerKey", "value": "ck"}, + {"key": "consumerSecret", "value": "cs"}, + {"key": "token", "value": ""}, + {"key": "tokenSecret", "value": ""}, + {"key": "signatureMethod", "value": "PLAINTEXT"}, + {"key": "timestamp", "value": "0"}, + {"key": "nonce", "value": "n"}, + {"key": "addParamsToHeader", "value": False}, + ], + } + url, hdr = apply_auth(auth, "https://x.io", {}, method="GET") + assert "oauth_consumer_key=ck" in url + assert "Authorization" not in hdr + + +class TestHawkAuth: + """Hawk authentication header generation.""" + + def test_basic_hawk(self) -> None: + """Produces Authorization: Hawk header with required fields.""" + auth = { + "type": "hawk", + "hawk": [ + {"key": "authId", "value": "myid"}, + {"key": "authKey", "value": "mykey"}, + {"key": "algorithm", "value": "sha256"}, + {"key": "nonce", "value": "testnonce"}, + {"key": "timestamp", "value": "1234567890"}, + ], + } + _, hdr = apply_auth(auth, "https://x.io/path", {}, method="GET") + val = hdr["Authorization"] + assert val.startswith("Hawk ") + assert 'id="myid"' in val + assert 'ts="1234567890"' in val + assert 'nonce="testnonce"' in val + assert 'mac="' in val + + def test_payload_hash_included_when_checkbox_true(self) -> None: + """Hawk includes payload hash when includePayloadHash is true.""" + auth = { + "type": "hawk", + "hawk": [ + {"key": "authId", "value": "myid"}, + {"key": "authKey", "value": "mykey"}, + {"key": "algorithm", "value": "sha256"}, + {"key": "nonce", "value": "n"}, + {"key": "timestamp", "value": "0"}, + {"key": "includePayloadHash", "value": True}, + ], + } + _, hdr = apply_auth(auth, "https://x.io", {}, method="POST", body="data") + assert 'hash="' in hdr["Authorization"] + + def test_payload_hash_excluded_when_checkbox_false(self) -> None: + """Hawk omits payload hash when includePayloadHash is false.""" + auth = { + "type": "hawk", + "hawk": [ + {"key": "authId", "value": "myid"}, + {"key": "authKey", "value": "mykey"}, + {"key": "algorithm", "value": "sha256"}, + {"key": "nonce", "value": "n"}, + {"key": "timestamp", "value": "0"}, + {"key": "includePayloadHash", "value": "false"}, + ], + } + _, hdr = apply_auth(auth, "https://x.io", {}, method="POST", body="data") + assert 'hash="' not in hdr["Authorization"] + + def test_payload_hash_excluded_by_default(self) -> None: + """Hawk omits payload hash when checkbox entry is absent.""" + auth = { + "type": "hawk", + "hawk": [ + {"key": "authId", "value": "myid"}, + {"key": "authKey", "value": "mykey"}, + {"key": "algorithm", "value": "sha256"}, + {"key": "nonce", "value": "n"}, + {"key": "timestamp", "value": "0"}, + ], + } + _, hdr = apply_auth(auth, "https://x.io", {}, method="POST", body="data") + assert 'hash="' not in hdr["Authorization"] + + """AWS Signature V4 header generation.""" + + def test_adds_required_headers(self) -> None: + """Adds Authorization, x-amz-date, x-amz-content-sha256 headers.""" + auth = { + "type": "awsv4", + "awsv4": [ + {"key": "accessKey", "value": "AKID"}, + {"key": "secretKey", "value": "secret"}, + {"key": "region", "value": "us-east-1"}, + {"key": "service", "value": "s3"}, + ], + } + _, hdr = apply_auth(auth, "https://s3.amazonaws.com/bucket", {}, method="GET") + assert hdr["Authorization"].startswith("AWS4-HMAC-SHA256") + assert "Credential=AKID/" in hdr["Authorization"] + assert "x-amz-date" in hdr + assert "x-amz-content-sha256" in hdr + + def test_session_token(self) -> None: + """Session token adds x-amz-security-token header.""" + auth = { + "type": "awsv4", + "awsv4": [ + {"key": "accessKey", "value": "AKID"}, + {"key": "secretKey", "value": "secret"}, + {"key": "region", "value": "us-west-2"}, + {"key": "service", "value": "execute-api"}, + {"key": "sessionToken", "value": "tokval"}, + ], + } + _, hdr = apply_auth(auth, "https://x.io", {}, method="POST") + assert hdr["x-amz-security-token"] == "tokval" + + +class TestJwtAuth: + """JWT Bearer token generation (HMAC algorithms via stdlib).""" + + def test_hs256_header(self) -> None: + """Generates a valid HS256 JWT in Authorization header.""" + auth = { + "type": "jwt", + "jwt": [ + {"key": "algorithm", "value": "HS256"}, + {"key": "secret", "value": "mysecret"}, + {"key": "payload", "value": '{"sub":"123"}'}, + {"key": "headers", "value": "{}"}, + {"key": "isSecretBase64Encoded", "value": "false"}, + {"key": "addTokenTo", "value": "header"}, + {"key": "headerPrefix", "value": "Bearer"}, + ], + } + _, hdr = apply_auth(auth, "https://x.io", {}) + assert "Authorization" in hdr + token = hdr["Authorization"].removeprefix("Bearer ") + parts = token.split(".") + assert len(parts) == 3 + # Decode header to verify algorithm + header_json = base64.urlsafe_b64decode(parts[0] + "==") + header_data = json.loads(header_json) + assert header_data["alg"] == "HS256" + assert header_data["typ"] == "JWT" + + def test_query_param(self) -> None: + """JWT added as query param when configured.""" + auth = { + "type": "jwt", + "jwt": [ + {"key": "algorithm", "value": "HS256"}, + {"key": "secret", "value": "s"}, + {"key": "payload", "value": "{}"}, + {"key": "headers", "value": "{}"}, + {"key": "isSecretBase64Encoded", "value": "false"}, + {"key": "addTokenTo", "value": "queryParams"}, + {"key": "queryParamKey", "value": "jwt"}, + ], + } + url, _ = apply_auth(auth, "https://x.io", {}) + assert "jwt=" in url + + +class TestAsapAuth: + """ASAP (Atlassian) auth — requires RSA, may fall back gracefully.""" + + def test_returns_unchanged_without_pyjwt(self) -> None: + """Without PyJWT, ASAP (RS256) returns headers unchanged.""" + auth = { + "type": "asap", + "asap": [ + {"key": "issuer", "value": "iss"}, + {"key": "audience", "value": "aud"}, + {"key": "privateKey", "value": "not-a-real-key"}, + {"key": "kid", "value": "kid1"}, + {"key": "algorithm", "value": "RS256"}, + ], + } + _, hdr = apply_auth(auth, "https://x.io", {}) + # RS256 requires PyJWT — without it, no header is added + # (unless PyJWT is installed, in which case it would fail on the bad key) + + +class TestNtlmAuth: + """NTLM auth — pass-through only (no pre-computable header).""" + + def test_noop(self) -> None: + """NTLM does not modify headers or URL.""" + auth = { + "type": "ntlm", + "ntlm": [ + {"key": "username", "value": "user"}, + {"key": "password", "value": "pass"}, + ], + } + url, hdr = apply_auth(auth, "https://x.io", {"Existing": "h"}) + assert url == "https://x.io" + assert hdr == {"Existing": "h"} + + +class TestEdgeGridAuth: + """Akamai EdgeGrid signature generation.""" + + def test_produces_eg1_header(self) -> None: + """Generates Authorization: EG1-HMAC-SHA256 header.""" + auth = { + "type": "edgegrid", + "edgegrid": [ + {"key": "accessToken", "value": "at"}, + {"key": "clientToken", "value": "ct"}, + {"key": "clientSecret", "value": base64.b64encode(b"sec").decode()}, + {"key": "nonce", "value": "nonce1"}, + {"key": "timestamp", "value": "20240101T00:00:00+0000"}, + ], + } + _, hdr = apply_auth(auth, "https://x.io/path", {}, method="GET") + val = hdr["Authorization"] + assert val.startswith("EG1-HMAC-SHA256") + assert "client_token=ct" in val + assert "access_token=at" in val + assert "signature=" in val + + +class TestBooleanEntries: + """Verify boolean entry values are handled correctly.""" + + def test_true_bool(self) -> None: + """Python True in entry maps to 'true' string.""" + auth = { + "type": "oauth1", + "oauth1": [ + {"key": "consumerKey", "value": "ck"}, + {"key": "consumerSecret", "value": "cs"}, + {"key": "token", "value": ""}, + {"key": "tokenSecret", "value": ""}, + {"key": "signatureMethod", "value": "PLAINTEXT"}, + {"key": "timestamp", "value": "0"}, + {"key": "nonce", "value": "n"}, + {"key": "addParamsToHeader", "value": True}, + ], + } + _, hdr = apply_auth(auth, "https://x.io", {}, method="GET") + # addParamsToHeader=True means header mode + assert "Authorization" in hdr + + +class TestOAuth1BodyHash: + """OAuth 1.0 body hash and new fields.""" + + def test_include_body_hash_sha1(self) -> None: + """Body hash is included when includeBodyHash is true (HMAC-SHA1).""" + auth = { + "type": "oauth1", + "oauth1": [ + {"key": "consumerKey", "value": "ck"}, + {"key": "consumerSecret", "value": "cs"}, + {"key": "token", "value": "tok"}, + {"key": "tokenSecret", "value": "ts"}, + {"key": "signatureMethod", "value": "HMAC-SHA1"}, + {"key": "timestamp", "value": "1234567890"}, + {"key": "nonce", "value": "testnonce"}, + {"key": "addParamsToHeader", "value": "true"}, + {"key": "includeBodyHash", "value": True}, + ], + } + _, hdr = apply_auth(auth, "https://x.io/api", {}, method="POST", body="data=1") + val = hdr["Authorization"] + assert "oauth_body_hash" in val + + def test_body_hash_not_included_when_false(self) -> None: + """Body hash is NOT included when includeBodyHash is false.""" + auth = { + "type": "oauth1", + "oauth1": [ + {"key": "consumerKey", "value": "ck"}, + {"key": "consumerSecret", "value": "cs"}, + {"key": "token", "value": "tok"}, + {"key": "tokenSecret", "value": "ts"}, + {"key": "signatureMethod", "value": "HMAC-SHA1"}, + {"key": "timestamp", "value": "1234567890"}, + {"key": "nonce", "value": "testnonce"}, + {"key": "addParamsToHeader", "value": "true"}, + {"key": "includeBodyHash", "value": False}, + ], + } + _, hdr = apply_auth(auth, "https://x.io/api", {}, method="POST", body="data=1") + val = hdr["Authorization"] + assert "oauth_body_hash" not in val + + def test_callback_url_included(self) -> None: + """Callback URL adds oauth_callback to the signature.""" + auth = { + "type": "oauth1", + "oauth1": [ + {"key": "consumerKey", "value": "ck"}, + {"key": "consumerSecret", "value": "cs"}, + {"key": "token", "value": ""}, + {"key": "tokenSecret", "value": ""}, + {"key": "signatureMethod", "value": "HMAC-SHA1"}, + {"key": "timestamp", "value": "0"}, + {"key": "nonce", "value": "n"}, + {"key": "addParamsToHeader", "value": "true"}, + {"key": "callbackUrl", "value": "https://example.com/cb"}, + ], + } + _, hdr = apply_auth(auth, "https://x.io", {}, method="GET") + assert "oauth_callback" in hdr["Authorization"] + + def test_verifier_included(self) -> None: + """Verifier adds oauth_verifier to the signature.""" + auth = { + "type": "oauth1", + "oauth1": [ + {"key": "consumerKey", "value": "ck"}, + {"key": "consumerSecret", "value": "cs"}, + {"key": "token", "value": "tok"}, + {"key": "tokenSecret", "value": "ts"}, + {"key": "signatureMethod", "value": "HMAC-SHA1"}, + {"key": "timestamp", "value": "0"}, + {"key": "nonce", "value": "n"}, + {"key": "addParamsToHeader", "value": "true"}, + {"key": "verifier", "value": "v123"}, + ], + } + _, hdr = apply_auth(auth, "https://x.io", {}, method="GET") + assert "oauth_verifier" in hdr["Authorization"] + + def test_add_params_to_url(self) -> None: + """addParamsToHeader=false appends oauth params to URL query string.""" + auth = { + "type": "oauth1", + "oauth1": [ + {"key": "consumerKey", "value": "ck"}, + {"key": "consumerSecret", "value": "cs"}, + {"key": "token", "value": ""}, + {"key": "tokenSecret", "value": ""}, + {"key": "signatureMethod", "value": "PLAINTEXT"}, + {"key": "timestamp", "value": "0"}, + {"key": "nonce", "value": "n"}, + {"key": "addParamsToHeader", "value": "false"}, + ], + } + url, hdr = apply_auth(auth, "https://x.io/path", {}, method="GET") + assert "Authorization" not in hdr + assert "oauth_consumer_key=ck" in url + + def test_add_params_to_body(self) -> None: + """addParamsToHeader=body appends oauth params to URL (body mode).""" + auth = { + "type": "oauth1", + "oauth1": [ + {"key": "consumerKey", "value": "ck"}, + {"key": "consumerSecret", "value": "cs"}, + {"key": "token", "value": ""}, + {"key": "tokenSecret", "value": ""}, + {"key": "signatureMethod", "value": "PLAINTEXT"}, + {"key": "timestamp", "value": "0"}, + {"key": "nonce", "value": "n"}, + {"key": "addParamsToHeader", "value": "body"}, + ], + } + url, hdr = apply_auth(auth, "https://x.io/path", {}, method="POST") + assert "Authorization" not in hdr + assert "oauth_consumer_key=ck" in url diff --git a/tests/unit/services/http/test_oauth2_service.py b/tests/unit/services/http/test_oauth2_service.py new file mode 100644 index 0000000..f693441 --- /dev/null +++ b/tests/unit/services/http/test_oauth2_service.py @@ -0,0 +1,286 @@ +"""Tests for the OAuth 2.0 token exchange service.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import httpx + +from services.http.oauth2_service import ( + OAuth2Service, + _error_result, + _parse_redirect, + _post_token_request, +) + +# ------------------------------------------------------------------ +# Helper function tests +# ------------------------------------------------------------------ + + +class TestErrorResult: + """Verify _error_result helper builds correct TypedDict.""" + + def test_fields(self) -> None: + """All fields present with error message.""" + result = _error_result("boom") + assert result["error"] == "boom" + assert result["access_token"] == "" + assert result["token_type"] == "" + assert result["expires_in"] == 0 + assert result["refresh_token"] == "" + assert result["scope"] == "" + + +class TestParseRedirect: + """Verify _parse_redirect extracts URI and port.""" + + def test_with_explicit_url(self) -> None: + """Explicit callback URL returns correct port.""" + uri, port = _parse_redirect("http://localhost:9876/callback") + assert uri == "http://localhost:9876/callback" + assert port == 9876 + + def test_default_when_empty(self) -> None: + """Empty callback falls back to localhost:5000.""" + uri, port = _parse_redirect("") + assert port == 5000 + assert "localhost" in uri + + def test_no_port_defaults_to_5000(self) -> None: + """URL without port defaults to 5000.""" + _, port = _parse_redirect("http://localhost/callback") + assert port == 5000 + + +# ------------------------------------------------------------------ +# Direct grant type tests (mock httpx) +# ------------------------------------------------------------------ + + +def _mock_token_response( + *, + access_token: str = "test_access_token", + token_type: str = "Bearer", + expires_in: int = 3600, + refresh_token: str = "test_refresh", + scope: str = "read", +) -> MagicMock: + """Create a mock httpx.Response for a successful token exchange.""" + resp = MagicMock(spec=httpx.Response) + resp.status_code = 200 + resp.json.return_value = { + "access_token": access_token, + "token_type": token_type, + "expires_in": expires_in, + "refresh_token": refresh_token, + "scope": scope, + } + resp.raise_for_status = MagicMock() + return resp + + +class TestPasswordGrant: + """Password Credentials grant — direct POST.""" + + def test_missing_token_url(self) -> None: + """Returns error when token URL is empty.""" + result = OAuth2Service._password_credentials({"accessTokenUrl": ""}) + assert result["error"] + + @patch("services.http.oauth2_service.httpx.Client") + def test_successful_exchange(self, mock_client_cls: MagicMock) -> None: + """Successful password grant returns token.""" + mock_resp = _mock_token_response() + mock_client = MagicMock() + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client.post.return_value = mock_resp + mock_client_cls.return_value = mock_client + + config = { + "grant_type": "password", + "accessTokenUrl": "https://auth.example.com/token", + "clientId": "my_client", + "clientSecret": "my_secret", + "username": "user", + "password": "pass", + "scope": "read", + "client_authentication": "header", + } + result = OAuth2Service.get_token(config) + + assert result["access_token"] == "test_access_token" + assert result["token_type"] == "Bearer" + assert result["error"] == "" + + # Verify client credentials in Basic Auth + call_kwargs = mock_client.post.call_args + assert call_kwargs.kwargs.get("auth") is not None + + @patch("services.http.oauth2_service.httpx.Client") + def test_body_auth(self, mock_client_cls: MagicMock) -> None: + """Body auth sends client_id/secret in POST body.""" + mock_resp = _mock_token_response() + mock_client = MagicMock() + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client.post.return_value = mock_resp + mock_client_cls.return_value = mock_client + + config = { + "grant_type": "password", + "accessTokenUrl": "https://auth.example.com/token", + "clientId": "my_client", + "clientSecret": "my_secret", + "username": "user", + "password": "pass", + "client_authentication": "body", + } + result = OAuth2Service.get_token(config) + + assert result["access_token"] == "test_access_token" + call_kwargs = mock_client.post.call_args + post_data = call_kwargs.kwargs.get("data", call_kwargs[1].get("data", {})) + assert post_data.get("client_id") == "my_client" + + +class TestClientCredentialsGrant: + """Client Credentials grant — direct POST.""" + + def test_missing_token_url(self) -> None: + """Returns error when token URL is empty.""" + result = OAuth2Service._client_credentials({"accessTokenUrl": ""}) + assert result["error"] + + @patch("services.http.oauth2_service.httpx.Client") + def test_successful_exchange(self, mock_client_cls: MagicMock) -> None: + """Successful client_credentials grant returns token.""" + mock_resp = _mock_token_response(refresh_token="") + mock_client = MagicMock() + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client.post.return_value = mock_resp + mock_client_cls.return_value = mock_client + + config = { + "grant_type": "client_credentials", + "accessTokenUrl": "https://auth.example.com/token", + "clientId": "service_client", + "clientSecret": "service_secret", + "scope": "api", + "client_authentication": "header", + } + result = OAuth2Service.get_token(config) + + assert result["access_token"] == "test_access_token" + assert result["error"] == "" + + @patch("services.http.oauth2_service.httpx.Client") + def test_scope_included(self, mock_client_cls: MagicMock) -> None: + """Scope is included in the POST data when provided.""" + mock_resp = _mock_token_response() + mock_client = MagicMock() + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client.post.return_value = mock_resp + mock_client_cls.return_value = mock_client + + config = { + "grant_type": "client_credentials", + "accessTokenUrl": "https://auth.example.com/token", + "clientId": "cid", + "clientSecret": "csec", + "scope": "admin", + "client_authentication": "body", + } + OAuth2Service.get_token(config) + + call_kwargs = mock_client.post.call_args + post_data = call_kwargs.kwargs.get("data", call_kwargs[1].get("data", {})) + assert post_data.get("scope") == "admin" + + +class TestPostTokenRequest: + """Verify _post_token_request handles errors.""" + + @patch("services.http.oauth2_service.httpx.Client") + def test_http_error_with_json(self, mock_client_cls: MagicMock) -> None: + """HTTP error with JSON body extracts error_description.""" + error_resp = MagicMock(spec=httpx.Response) + error_resp.status_code = 400 + error_resp.json.return_value = { + "error": "invalid_grant", + "error_description": "Bad credentials", + } + exc = httpx.HTTPStatusError( + "400", + request=MagicMock(), + response=error_resp, + ) + error_resp.raise_for_status.side_effect = exc + + mock_client = MagicMock() + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client.post.return_value = error_resp + mock_client_cls.return_value = mock_client + + result = _post_token_request( + "https://auth.example.com/token", + {"grant_type": "password"}, + "cid", + "csec", + "header", + ) + assert "Bad credentials" in result["error"] + + @patch("services.http.oauth2_service.httpx.Client") + def test_connection_error(self, mock_client_cls: MagicMock) -> None: + """Connection error returns error result.""" + mock_client = MagicMock() + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client.post.side_effect = httpx.ConnectError("Connection refused") + mock_client_cls.return_value = mock_client + + result = _post_token_request( + "https://auth.example.com/token", + {"grant_type": "client_credentials"}, + "cid", + "csec", + "header", + ) + assert result["error"] + + +class TestGetTokenDispatch: + """Verify get_token dispatches to correct handler.""" + + def test_unknown_grant_type(self) -> None: + """Unknown grant type returns error.""" + result = OAuth2Service.get_token({"grant_type": "unknown"}) + assert "Unknown grant type" in result["error"] + + def test_authorization_code_missing_fields(self) -> None: + """Auth code grant with missing fields returns error.""" + result = OAuth2Service.get_token( + { + "grant_type": "authorization_code", + "authUrl": "", + "accessTokenUrl": "", + "clientId": "", + } + ) + assert result["error"] + + def test_implicit_missing_fields(self) -> None: + """Implicit grant with missing fields returns error.""" + result = OAuth2Service.get_token( + { + "grant_type": "implicit", + "authUrl": "", + "clientId": "", + } + ) + assert result["error"] diff --git a/tests/unit/services/http/test_snippet_compiled.py b/tests/unit/services/http/test_snippet_compiled.py new file mode 100644 index 0000000..db3df0e --- /dev/null +++ b/tests/unit/services/http/test_snippet_compiled.py @@ -0,0 +1,580 @@ +"""Tests for compiled / statically-typed language snippet generators.""" + +from __future__ import annotations + +from services.http.snippet_generator import SnippetGenerator + + +class TestGoNative: + """Verify Go (net/http) snippet generation.""" + + def test_basic_get(self) -> None: + """Snippet includes net/http import and request construction.""" + result = SnippetGenerator.generate( + "Go (net/http)", method="GET", url="https://api.example.com" + ) + assert "net/http" in result + assert "http.NewRequest" in result + assert "https://api.example.com" in result + + def test_with_headers(self) -> None: + """Snippet sets headers via req.Header.Set.""" + result = SnippetGenerator.generate( + "Go (net/http)", + method="POST", + url="https://api.example.com", + headers="Content-Type: application/json", + ) + assert "Header.Set" in result + assert "Content-Type" in result + + def test_with_body(self) -> None: + """Snippet includes strings.NewReader body.""" + result = SnippetGenerator.generate( + "Go (net/http)", + method="POST", + url="https://api.example.com", + body='{"key": "value"}', + ) + assert "strings.NewReader" in result + + def test_timeout_option(self) -> None: + """Timeout option sets Timeout on http.Client.""" + result = SnippetGenerator.generate( + "Go (net/http)", + method="GET", + url="https://example.com", + options={"request_timeout": 10}, + ) + assert "Timeout:" in result + assert "10" in result + + def test_no_redirect(self) -> None: + """Disabled redirects add CheckRedirect to client.""" + result = SnippetGenerator.generate( + "Go (net/http)", + method="GET", + url="https://example.com", + options={"follow_redirect": False}, + ) + assert "CheckRedirect" in result + assert "ErrUseLastResponse" in result + + +class TestRustReqwest: + """Verify Rust (reqwest) snippet generation.""" + + def test_basic_get(self) -> None: + """Snippet includes reqwest Client and method.""" + result = SnippetGenerator.generate( + "Rust (reqwest)", method="GET", url="https://api.example.com" + ) + assert "reqwest" in result + assert ".get(" in result + assert "https://api.example.com" in result + + def test_with_headers(self) -> None: + """Snippet chains .header() calls.""" + result = SnippetGenerator.generate( + "Rust (reqwest)", + method="POST", + url="https://api.example.com", + headers="Content-Type: application/json", + ) + assert ".header(" in result + assert "Content-Type" in result + + def test_with_body(self) -> None: + """Snippet chains .body() for POST.""" + result = SnippetGenerator.generate( + "Rust (reqwest)", + method="POST", + url="https://api.example.com", + body='{"key": "value"}', + ) + assert ".body(" in result + + def test_timeout_option(self) -> None: + """Timeout option chains .timeout().""" + result = SnippetGenerator.generate( + "Rust (reqwest)", + method="GET", + url="https://example.com", + options={"request_timeout": 30}, + ) + assert ".timeout(" in result + assert "30" in result + + def test_no_redirect(self) -> None: + """Disabled redirects add redirect policy.""" + result = SnippetGenerator.generate( + "Rust (reqwest)", + method="GET", + url="https://example.com", + options={"follow_redirect": False}, + ) + assert ".redirect(" in result + assert "Policy::none()" in result + + +class TestCLibcurl: + """Verify C (libcurl) snippet generation.""" + + def test_basic_get(self) -> None: + """Snippet includes curl_easy_init and URL.""" + result = SnippetGenerator.generate( + "C (libcurl)", method="GET", url="https://api.example.com" + ) + assert "curl_easy_init" in result + assert "CURLOPT_URL" in result + assert "https://api.example.com" in result + + def test_with_headers(self) -> None: + """Snippet adds headers via curl_slist_append.""" + result = SnippetGenerator.generate( + "C (libcurl)", + method="POST", + url="https://api.example.com", + headers="Content-Type: application/json", + ) + assert "curl_slist_append" in result + assert "Content-Type" in result + + def test_with_body(self) -> None: + """Snippet sets CURLOPT_POSTFIELDS.""" + result = SnippetGenerator.generate( + "C (libcurl)", + method="POST", + url="https://api.example.com", + body='{"key": "value"}', + ) + assert "CURLOPT_POSTFIELDS" in result + + def test_timeout_option(self) -> None: + """Timeout option sets CURLOPT_TIMEOUT.""" + result = SnippetGenerator.generate( + "C (libcurl)", + method="GET", + url="https://example.com", + options={"request_timeout": 20}, + ) + assert "CURLOPT_TIMEOUT" in result + + def test_follow_redirect(self) -> None: + """Default includes CURLOPT_FOLLOWLOCATION.""" + result = SnippetGenerator.generate( + "C (libcurl)", + method="GET", + url="https://example.com", + ) + assert "CURLOPT_FOLLOWLOCATION" in result + + def test_no_redirect(self) -> None: + """Disabled redirects omit CURLOPT_FOLLOWLOCATION.""" + result = SnippetGenerator.generate( + "C (libcurl)", + method="GET", + url="https://example.com", + options={"follow_redirect": False}, + ) + assert "CURLOPT_FOLLOWLOCATION" not in result + + def test_no_boilerplate(self) -> None: + """Boilerplate disabled omits includes and main wrapper.""" + result = SnippetGenerator.generate( + "C (libcurl)", + method="GET", + url="https://example.com", + options={"include_boilerplate": False}, + ) + assert "#include" not in result + assert "int main" not in result + assert "curl_easy_init" in result + + def test_boilerplate_default(self) -> None: + """Default includes #include and int main wrapper.""" + result = SnippetGenerator.generate( + "C (libcurl)", + method="GET", + url="https://example.com", + ) + assert "#include" in result + assert "int main" in result + + +class TestSwiftUrlsession: + """Verify Swift (URLSession) snippet generation.""" + + def test_basic_get(self) -> None: + """Snippet includes URLSession and URL construction.""" + result = SnippetGenerator.generate( + "Swift (URLSession)", method="GET", url="https://api.example.com" + ) + assert "URLSession" in result + assert "URL(string:" in result + assert "https://api.example.com" in result + + def test_with_headers(self) -> None: + """Snippet sets headers via setValue forHTTPHeaderField.""" + result = SnippetGenerator.generate( + "Swift (URLSession)", + method="POST", + url="https://api.example.com", + headers="Content-Type: application/json", + ) + assert "setValue" in result or "addValue" in result + assert "Content-Type" in result + + def test_with_body(self) -> None: + """Snippet sets httpBody on request.""" + result = SnippetGenerator.generate( + "Swift (URLSession)", + method="POST", + url="https://api.example.com", + body='{"key": "value"}', + ) + assert "httpBody" in result + + def test_timeout_option(self) -> None: + """Timeout option sets timeoutInterval.""" + result = SnippetGenerator.generate( + "Swift (URLSession)", + method="GET", + url="https://example.com", + options={"request_timeout": 15}, + ) + assert "timeoutInterval" in result or "timeout" in result.lower() + + def test_no_boilerplate(self) -> None: + """Boilerplate disabled omits import Foundation.""" + result = SnippetGenerator.generate( + "Swift (URLSession)", + method="GET", + url="https://example.com", + options={"include_boilerplate": False}, + ) + assert "import Foundation" not in result + assert "URLRequest" in result + + def test_boilerplate_default(self) -> None: + """Default includes import Foundation.""" + result = SnippetGenerator.generate( + "Swift (URLSession)", + method="GET", + url="https://example.com", + ) + assert "import Foundation" in result + + +class TestJavaOkhttp: + """Verify Java (OkHttp) snippet generation.""" + + def test_basic_get(self) -> None: + """Snippet includes OkHttpClient and Request.Builder.""" + result = SnippetGenerator.generate( + "Java (OkHttp)", method="GET", url="https://api.example.com" + ) + assert "OkHttpClient" in result + assert "Request.Builder" in result + assert "https://api.example.com" in result + + def test_with_headers(self) -> None: + """Snippet chains .addHeader() calls.""" + result = SnippetGenerator.generate( + "Java (OkHttp)", + method="POST", + url="https://api.example.com", + headers="Content-Type: application/json", + ) + assert "addHeader" in result or ".header(" in result + assert "Content-Type" in result + + def test_with_body(self) -> None: + """Snippet includes RequestBody.create.""" + result = SnippetGenerator.generate( + "Java (OkHttp)", + method="POST", + url="https://api.example.com", + body='{"key": "value"}', + ) + assert "RequestBody" in result + + def test_timeout_option(self) -> None: + """Timeout option sets connectTimeout or readTimeout.""" + result = SnippetGenerator.generate( + "Java (OkHttp)", + method="GET", + url="https://example.com", + options={"request_timeout": 10}, + ) + timeout_keywords = ["connectTimeout", "readTimeout", "callTimeout", "timeout"] + assert any(kw in result for kw in timeout_keywords) + + def test_no_redirect(self) -> None: + """Disabled redirects add followRedirects(false).""" + result = SnippetGenerator.generate( + "Java (OkHttp)", + method="GET", + url="https://example.com", + options={"follow_redirect": False}, + ) + assert "followRedirects(false)" in result + + def test_no_boilerplate(self) -> None: + """Boilerplate disabled omits imports and class wrapper.""" + result = SnippetGenerator.generate( + "Java (OkHttp)", + method="GET", + url="https://example.com", + options={"include_boilerplate": False}, + ) + assert "import " not in result + assert "OkHttpClient" in result + + def test_boilerplate_default(self) -> None: + """Default includes import statements.""" + result = SnippetGenerator.generate( + "Java (OkHttp)", + method="GET", + url="https://example.com", + ) + assert "import " in result + + +class TestKotlinOkhttp: + """Verify Kotlin (OkHttp) snippet generation.""" + + def test_basic_get(self) -> None: + """Snippet includes OkHttpClient and Request.Builder.""" + result = SnippetGenerator.generate( + "Kotlin (OkHttp)", method="GET", url="https://api.example.com" + ) + assert "OkHttpClient" in result + assert "Request.Builder" in result + assert "https://api.example.com" in result + + def test_with_headers(self) -> None: + """Snippet chains .addHeader() calls.""" + result = SnippetGenerator.generate( + "Kotlin (OkHttp)", + method="POST", + url="https://api.example.com", + headers="Content-Type: application/json", + ) + assert "addHeader" in result or ".header(" in result + assert "Content-Type" in result + + def test_with_body(self) -> None: + """Snippet includes RequestBody or toRequestBody.""" + result = SnippetGenerator.generate( + "Kotlin (OkHttp)", + method="POST", + url="https://api.example.com", + body='{"key": "value"}', + ) + assert "RequestBody" in result or "toRequestBody" in result + + def test_timeout_option(self) -> None: + """Timeout option configures OkHttp timeouts.""" + result = SnippetGenerator.generate( + "Kotlin (OkHttp)", + method="GET", + url="https://example.com", + options={"request_timeout": 10}, + ) + timeout_keywords = ["connectTimeout", "readTimeout", "callTimeout", "timeout"] + assert any(kw in result for kw in timeout_keywords) + + def test_no_redirect(self) -> None: + """Disabled redirects add followRedirects(false).""" + result = SnippetGenerator.generate( + "Kotlin (OkHttp)", + method="GET", + url="https://example.com", + options={"follow_redirect": False}, + ) + assert "followRedirects(false)" in result + + def test_no_boilerplate(self) -> None: + """Boilerplate disabled omits import statements.""" + result = SnippetGenerator.generate( + "Kotlin (OkHttp)", + method="GET", + url="https://example.com", + options={"include_boilerplate": False}, + ) + assert "import " not in result + assert "OkHttpClient" in result + + def test_boilerplate_default(self) -> None: + """Default includes import statements.""" + result = SnippetGenerator.generate( + "Kotlin (OkHttp)", + method="GET", + url="https://example.com", + ) + assert "import " in result + + +class TestCsharpHttpclient: + """Verify C# (HttpClient) snippet generation.""" + + def test_basic_get(self) -> None: + """Snippet includes HttpClient and method call.""" + result = SnippetGenerator.generate( + "C# (HttpClient)", method="GET", url="https://api.example.com" + ) + assert "HttpClient" in result + assert "https://api.example.com" in result + + def test_with_headers(self) -> None: + """Snippet sets request headers.""" + result = SnippetGenerator.generate( + "C# (HttpClient)", + method="POST", + url="https://api.example.com", + headers="Accept: application/json", + ) + assert "Accept" in result + assert "Headers.Add" in result + + def test_with_body(self) -> None: + """Snippet includes StringContent for body.""" + result = SnippetGenerator.generate( + "C# (HttpClient)", + method="POST", + url="https://api.example.com", + body='{"key": "value"}', + ) + assert "StringContent" in result + + def test_timeout_option(self) -> None: + """Timeout option sets client.Timeout.""" + result = SnippetGenerator.generate( + "C# (HttpClient)", + method="GET", + url="https://example.com", + options={"request_timeout": 15}, + ) + assert "Timeout" in result + + def test_no_redirect(self) -> None: + """Disabled redirects use HttpClientHandler with AllowAutoRedirect.""" + result = SnippetGenerator.generate( + "C# (HttpClient)", + method="GET", + url="https://example.com", + options={"follow_redirect": False}, + ) + assert "AllowAutoRedirect = false" in result + assert "HttpClientHandler" in result + + def test_no_boilerplate(self) -> None: + """Boilerplate disabled omits using directive.""" + result = SnippetGenerator.generate( + "C# (HttpClient)", + method="GET", + url="https://example.com", + options={"include_boilerplate": False}, + ) + assert "using System" not in result + assert "HttpClient" in result + + def test_boilerplate_default(self) -> None: + """Default includes using directive.""" + result = SnippetGenerator.generate( + "C# (HttpClient)", + method="GET", + url="https://example.com", + ) + assert "using System" in result + + +class TestCsharpRestsharp: + """Verify C# (RestSharp) snippet generation.""" + + def test_basic_get(self) -> None: + """Snippet includes RestClient and RestRequest.""" + result = SnippetGenerator.generate( + "C# (RestSharp)", method="GET", url="https://api.example.com" + ) + assert "RestClient" in result + assert "RestRequest" in result + assert "https://api.example.com" in result + + def test_with_headers(self) -> None: + """Snippet adds headers via AddHeader.""" + result = SnippetGenerator.generate( + "C# (RestSharp)", + method="POST", + url="https://api.example.com", + headers="Accept: application/json", + ) + assert "AddHeader" in result + assert "Accept" in result + + def test_with_body(self) -> None: + """Snippet includes AddStringBody for request body.""" + result = SnippetGenerator.generate( + "C# (RestSharp)", + method="POST", + url="https://api.example.com", + body='{"key": "value"}', + ) + assert "AddStringBody" in result + + def test_timeout_option(self) -> None: + """Timeout option sets MaxTimeout.""" + result = SnippetGenerator.generate( + "C# (RestSharp)", + method="GET", + url="https://example.com", + options={"request_timeout": 10}, + ) + assert "MaxTimeout" in result + assert "10000" in result + + def test_no_redirect(self) -> None: + """Disabled redirects set FollowRedirects = false.""" + result = SnippetGenerator.generate( + "C# (RestSharp)", + method="GET", + url="https://example.com", + options={"follow_redirect": False}, + ) + assert "FollowRedirects = false" in result + + def test_no_boilerplate(self) -> None: + """Boilerplate disabled omits using directive.""" + result = SnippetGenerator.generate( + "C# (RestSharp)", + method="GET", + url="https://example.com", + options={"include_boilerplate": False}, + ) + assert "using RestSharp" not in result + assert "RestClient" in result + + def test_boilerplate_default(self) -> None: + """Default includes using RestSharp directive.""" + result = SnippetGenerator.generate( + "C# (RestSharp)", + method="GET", + url="https://example.com", + ) + assert "using RestSharp" in result + + def test_method_mapping(self) -> None: + """HTTP methods map to RestSharp Method enum values.""" + for method, expected in [("POST", "Method.Post"), ("PUT", "Method.Put")]: + result = SnippetGenerator.generate( + "C# (RestSharp)", method=method, url="https://example.com" + ) + assert expected in result + + def test_execute_async(self) -> None: + """Snippet uses ExecuteAsync for the request call.""" + result = SnippetGenerator.generate( + "C# (RestSharp)", method="GET", url="https://example.com" + ) + assert "ExecuteAsync" in result diff --git a/tests/unit/services/http/test_snippet_dynamic.py b/tests/unit/services/http/test_snippet_dynamic.py new file mode 100644 index 0000000..cace6d8 --- /dev/null +++ b/tests/unit/services/http/test_snippet_dynamic.py @@ -0,0 +1,455 @@ +"""Tests for dynamic / interpreted language snippet generators.""" + +from __future__ import annotations + +from services.http.snippet_generator import SnippetGenerator + + +class TestPythonHttpClient: + """Verify Python http.client snippet generation.""" + + def test_basic_get(self) -> None: + """Snippet includes http.client import and connection.""" + result = SnippetGenerator.generate( + "Python (http.client)", method="GET", url="https://api.example.com/users" + ) + assert "import http.client" in result + assert "HTTPSConnection" in result + assert "api.example.com" in result + + def test_with_headers(self) -> None: + """Snippet includes headers dict.""" + result = SnippetGenerator.generate( + "Python (http.client)", + method="POST", + url="https://api.example.com/data", + headers="Content-Type: application/json", + ) + assert "Content-Type" in result + assert "headers" in result + + def test_with_body(self) -> None: + """Snippet passes body to conn.request.""" + result = SnippetGenerator.generate( + "Python (http.client)", + method="POST", + url="https://api.example.com/data", + body='{"key": "value"}', + ) + assert '{"key": "value"}' in result + + def test_timeout_option(self) -> None: + """Timeout option adds timeout parameter.""" + result = SnippetGenerator.generate( + "Python (http.client)", + method="GET", + url="https://example.com", + options={"request_timeout": 10}, + ) + assert "timeout=10" in result + + +class TestNodejsAxios: + """Verify Node.js Axios snippet generation.""" + + def test_basic_get(self) -> None: + """Snippet includes axios require and method.""" + result = SnippetGenerator.generate( + "NodeJS (Axios)", method="GET", url="https://api.example.com" + ) + assert "axios" in result + assert '"get"' in result + assert "https://api.example.com" in result + + def test_with_headers(self) -> None: + """Snippet includes headers in config.""" + result = SnippetGenerator.generate( + "NodeJS (Axios)", + method="POST", + url="https://api.example.com", + headers="Content-Type: application/json", + ) + assert "headers" in result + assert "Content-Type" in result + + def test_with_body(self) -> None: + """Snippet includes data in config.""" + result = SnippetGenerator.generate( + "NodeJS (Axios)", + method="POST", + url="https://api.example.com", + body='{"key": "value"}', + ) + assert "data:" in result + + def test_timeout_option(self) -> None: + """Timeout option adds timeout to config.""" + result = SnippetGenerator.generate( + "NodeJS (Axios)", + method="GET", + url="https://example.com", + options={"request_timeout": 5}, + ) + assert "timeout: 5000" in result + + def test_no_redirect(self) -> None: + """Redirects disabled adds maxRedirects: 0 to config.""" + result = SnippetGenerator.generate( + "NodeJS (Axios)", + method="GET", + url="https://example.com", + options={"follow_redirect": False}, + ) + assert "maxRedirects: 0" in result + + def test_async_await(self) -> None: + """Async/await option generates async function syntax.""" + result = SnippetGenerator.generate( + "NodeJS (Axios)", + method="GET", + url="https://example.com", + options={"async_await": True}, + ) + assert "async function" in result + assert "await axios" in result + assert ".then(" not in result + + def test_default_uses_then(self) -> None: + """Default mode uses .then() promise chains.""" + result = SnippetGenerator.generate( + "NodeJS (Axios)", + method="GET", + url="https://example.com", + ) + assert ".then(" in result + assert "async function" not in result + + +class TestNodejsNative: + """Verify Node.js native http/https snippet generation.""" + + def test_basic_get(self) -> None: + """Snippet includes https require and request options.""" + result = SnippetGenerator.generate( + "NodeJS (Native)", method="GET", url="https://api.example.com/users" + ) + assert 'require("https")' in result + assert '"GET"' in result + assert "api.example.com" in result + + def test_http_url(self) -> None: + """HTTP URLs use the http module.""" + result = SnippetGenerator.generate( + "NodeJS (Native)", method="GET", url="http://localhost:3000/api" + ) + assert 'require("http")' in result + + def test_with_body(self) -> None: + """Snippet uses req.write for body.""" + result = SnippetGenerator.generate( + "NodeJS (Native)", + method="POST", + url="https://api.example.com/data", + body='{"key": "value"}', + ) + assert "req.write" in result + + def test_with_headers(self) -> None: + """Snippet includes headers in options.""" + result = SnippetGenerator.generate( + "NodeJS (Native)", + method="POST", + url="https://api.example.com/data", + headers="Content-Type: application/json", + ) + assert "headers" in result + assert "Content-Type" in result + + def test_es6_features(self) -> None: + """ES6 features option uses import instead of require.""" + result = SnippetGenerator.generate( + "NodeJS (Native)", + method="GET", + url="https://api.example.com", + options={"es6_features": True}, + ) + assert "import https" in result + assert "require(" not in result + + def test_default_uses_require(self) -> None: + """Default mode uses require() syntax.""" + result = SnippetGenerator.generate( + "NodeJS (Native)", + method="GET", + url="https://api.example.com", + ) + assert 'require("https")' in result + assert "import " not in result + + +class TestJavascriptXhr: + """Verify JavaScript XMLHttpRequest snippet generation.""" + + def test_basic_get(self) -> None: + """Snippet includes XMLHttpRequest open and send.""" + result = SnippetGenerator.generate( + "JavaScript (XHR)", method="GET", url="https://api.example.com" + ) + assert "XMLHttpRequest" in result + assert "xhr.open" in result + assert '"GET"' in result + assert "xhr.send()" in result + + def test_with_headers(self) -> None: + """Snippet sets request headers.""" + result = SnippetGenerator.generate( + "JavaScript (XHR)", + method="POST", + url="https://api.example.com", + headers="Content-Type: application/json", + ) + assert "setRequestHeader" in result + assert "Content-Type" in result + + def test_with_body(self) -> None: + """Snippet passes body to xhr.send.""" + result = SnippetGenerator.generate( + "JavaScript (XHR)", + method="POST", + url="https://api.example.com", + body='{"key": "value"}', + ) + assert "xhr.send(" in result + assert "key" in result + + def test_timeout_option(self) -> None: + """Timeout option sets xhr.timeout.""" + result = SnippetGenerator.generate( + "JavaScript (XHR)", + method="GET", + url="https://example.com", + options={"request_timeout": 10}, + ) + assert "xhr.timeout = 10000" in result + + +class TestRubyNethttp: + """Verify Ruby Net::HTTP snippet generation.""" + + def test_basic_get(self) -> None: + """Snippet includes Net::HTTP and URI.parse.""" + result = SnippetGenerator.generate( + "Ruby (Net::HTTP)", method="GET", url="https://api.example.com" + ) + assert "net/http" in result + assert "URI.parse" in result + assert "Net::HTTP::Get" in result + + def test_https(self) -> None: + """HTTPS URLs enable SSL.""" + result = SnippetGenerator.generate( + "Ruby (Net::HTTP)", method="GET", url="https://secure.example.com" + ) + assert "use_ssl = true" in result + + def test_with_body(self) -> None: + """Snippet sets request body.""" + result = SnippetGenerator.generate( + "Ruby (Net::HTTP)", + method="POST", + url="https://api.example.com", + body='{"key": "value"}', + ) + assert "request.body" in result + + def test_timeout_option(self) -> None: + """Timeout option sets read_timeout.""" + result = SnippetGenerator.generate( + "Ruby (Net::HTTP)", + method="GET", + url="https://example.com", + options={"request_timeout": 15}, + ) + assert "read_timeout = 15" in result + + def test_no_redirect(self) -> None: + """Redirects disabled sets max_retries = 0.""" + result = SnippetGenerator.generate( + "Ruby (Net::HTTP)", + method="GET", + url="https://example.com", + options={"follow_redirect": False}, + ) + assert "max_retries = 0" in result + + +class TestPhpCurl: + """Verify PHP cURL snippet generation.""" + + def test_basic_get(self) -> None: + """Snippet includes curl_init and URL.""" + result = SnippetGenerator.generate( + "PHP (cURL)", method="GET", url="https://api.example.com" + ) + assert "curl_init" in result + assert "https://api.example.com" in result + assert "curl_exec" in result + + def test_with_headers(self) -> None: + """Snippet includes CURLOPT_HTTPHEADER.""" + result = SnippetGenerator.generate( + "PHP (cURL)", + method="POST", + url="https://api.example.com", + headers="Content-Type: application/json", + ) + assert "CURLOPT_HTTPHEADER" in result + assert "Content-Type" in result + + def test_with_body(self) -> None: + """Snippet includes CURLOPT_POSTFIELDS.""" + result = SnippetGenerator.generate( + "PHP (cURL)", + method="POST", + url="https://api.example.com", + body='{"key": "value"}', + ) + assert "CURLOPT_POSTFIELDS" in result + + def test_timeout_option(self) -> None: + """Timeout option adds CURLOPT_TIMEOUT.""" + result = SnippetGenerator.generate( + "PHP (cURL)", + method="GET", + url="https://example.com", + options={"request_timeout": 30}, + ) + assert "CURLOPT_TIMEOUT" in result + + def test_follow_redirect(self) -> None: + """Follow redirects adds CURLOPT_FOLLOWLOCATION.""" + result = SnippetGenerator.generate( + "PHP (cURL)", + method="GET", + url="https://example.com", + options={"follow_redirect": True}, + ) + assert "CURLOPT_FOLLOWLOCATION" in result + + def test_no_redirect(self) -> None: + """Redirects disabled omits CURLOPT_FOLLOWLOCATION.""" + result = SnippetGenerator.generate( + "PHP (cURL)", + method="GET", + url="https://example.com", + options={"follow_redirect": False}, + ) + assert "CURLOPT_FOLLOWLOCATION" not in result + + +class TestPhpGuzzle: + """Verify PHP Guzzle snippet generation.""" + + def test_basic_get(self) -> None: + """Snippet includes GuzzleHttp Client.""" + result = SnippetGenerator.generate( + "PHP (Guzzle)", method="GET", url="https://api.example.com" + ) + assert "GuzzleHttp" in result + assert "https://api.example.com" in result + + def test_with_headers(self) -> None: + """Snippet includes headers option.""" + result = SnippetGenerator.generate( + "PHP (Guzzle)", + method="POST", + url="https://api.example.com", + headers="Accept: application/json", + ) + assert "'headers'" in result + assert "Accept" in result + + def test_with_json_body(self) -> None: + """Snippet uses json option for JSON body.""" + result = SnippetGenerator.generate( + "PHP (Guzzle)", + method="POST", + url="https://api.example.com", + body='{"key": "value"}', + ) + assert "'json'" in result + + def test_timeout_option(self) -> None: + """Timeout option adds timeout to request options.""" + result = SnippetGenerator.generate( + "PHP (Guzzle)", + method="GET", + url="https://example.com", + options={"request_timeout": 20}, + ) + assert "'timeout' => 20" in result + + def test_no_redirect(self) -> None: + """Redirects disabled adds allow_redirects => false.""" + result = SnippetGenerator.generate( + "PHP (Guzzle)", + method="GET", + url="https://example.com", + options={"follow_redirect": False}, + ) + assert "'allow_redirects' => false" in result + + +class TestDartHttp: + """Verify Dart http snippet generation.""" + + def test_basic_get(self) -> None: + """Snippet includes http import and Uri.parse.""" + result = SnippetGenerator.generate( + "Dart (http)", method="GET", url="https://api.example.com" + ) + assert "http" in result + assert "Uri.parse" in result + assert "https://api.example.com" in result + + def test_with_headers(self) -> None: + """Snippet includes headers map.""" + result = SnippetGenerator.generate( + "Dart (http)", + method="POST", + url="https://api.example.com", + headers="Content-Type: application/json", + ) + assert "headers" in result + assert "Content-Type" in result + + def test_with_body(self) -> None: + """Snippet includes body parameter for POST.""" + result = SnippetGenerator.generate( + "Dart (http)", + method="POST", + url="https://api.example.com", + body='{"key": "value"}', + ) + assert "body:" in result + + def test_no_boilerplate(self) -> None: + """Boilerplate disabled omits import directive.""" + result = SnippetGenerator.generate( + "Dart (http)", + method="GET", + url="https://example.com", + options={"include_boilerplate": False}, + ) + assert "import " not in result + assert "http.get" in result + + def test_boilerplate_default(self) -> None: + """Default includes package:http import.""" + result = SnippetGenerator.generate( + "Dart (http)", + method="GET", + url="https://example.com", + ) + assert "import " in result + assert "package:http" in result diff --git a/tests/unit/services/http/test_snippet_generator.py b/tests/unit/services/http/test_snippet_generator.py index 054c910..4d062ce 100644 --- a/tests/unit/services/http/test_snippet_generator.py +++ b/tests/unit/services/http/test_snippet_generator.py @@ -2,40 +2,43 @@ from __future__ import annotations -from services.http.snippet_generator import SnippetGenerator +from services.http.snippet_generator import SnippetGenerator, SnippetOptions class TestSnippetGenerator: """Verify code snippet generation for various languages.""" def test_available_languages(self) -> None: - """Returns at least cURL, Python, and JavaScript.""" + """Returns all 23 supported language variants.""" langs = SnippetGenerator.available_languages() assert "cURL" in langs assert "Python (requests)" in langs assert "JavaScript (fetch)" in langs + assert "C# (RestSharp)" in langs + assert len(langs) == 23 def test_curl_basic_get(self) -> None: - """CURL snippet for a simple GET.""" - result = SnippetGenerator.curl(method="GET", url="https://api.example.com") + """CURL snippet for a simple GET with default long-form flags.""" + result = SnippetGenerator.generate("cURL", method="GET", url="https://api.example.com") assert "curl" in result - assert "-X" in result - assert "GET" in result + assert "--request GET" in result assert "https://api.example.com" in result def test_curl_with_headers(self) -> None: - """CURL snippet includes header flags.""" - result = SnippetGenerator.curl( + """CURL snippet includes header flags in long form by default.""" + result = SnippetGenerator.generate( + "cURL", method="POST", url="https://api.example.com", headers="Content-Type: application/json", ) - assert "-H" in result + assert "--header" in result assert "Content-Type: application/json" in result def test_curl_with_body(self) -> None: """CURL snippet includes -d flag for body.""" - result = SnippetGenerator.curl( + result = SnippetGenerator.generate( + "cURL", method="POST", url="https://api.example.com", body='{"key": "value"}', @@ -44,14 +47,17 @@ def test_curl_with_body(self) -> None: def test_python_requests_get(self) -> None: """Python snippet includes requests import and method call.""" - result = SnippetGenerator.python_requests(method="GET", url="https://api.example.com") + result = SnippetGenerator.generate( + "Python (requests)", method="GET", url="https://api.example.com" + ) assert "import requests" in result assert "requests.get" in result assert "https://api.example.com" in result def test_python_requests_with_json_body(self) -> None: """Python snippet uses json parameter for JSON bodies.""" - result = SnippetGenerator.python_requests( + result = SnippetGenerator.generate( + "Python (requests)", method="POST", url="https://api.example.com", body='{"key": "value"}', @@ -60,7 +66,9 @@ def test_python_requests_with_json_body(self) -> None: def test_javascript_fetch_basic(self) -> None: """JavaScript fetch snippet includes URL and method.""" - result = SnippetGenerator.javascript_fetch(method="GET", url="https://api.example.com") + result = SnippetGenerator.generate( + "JavaScript (fetch)", method="GET", url="https://api.example.com" + ) assert "fetch(" in result assert "https://api.example.com" in result assert '"GET"' in result @@ -82,5 +90,317 @@ def test_generate_dispatches_correctly(self) -> None: def test_generate_unknown_language(self) -> None: """Unknown language returns an unsupported message.""" - result = SnippetGenerator.generate("Ruby", method="GET", url="https://example.com") + result = SnippetGenerator.generate("COBOL", method="GET", url="https://example.com") assert "Unsupported" in result + + def test_curl_with_bearer_auth(self) -> None: + """CURL snippet includes Authorization header for bearer auth.""" + auth = {"type": "bearer", "bearer": [{"key": "token", "value": "abc123"}]} + result = SnippetGenerator.generate( + "cURL", + method="GET", + url="https://api.example.com", + auth=auth, + ) + assert "Authorization: Bearer abc123" in result + + def test_python_with_basic_auth(self) -> None: + """Python snippet includes Authorization header for basic auth.""" + auth = { + "type": "basic", + "basic": [ + {"key": "username", "value": "user"}, + {"key": "password", "value": "pass"}, + ], + } + result = SnippetGenerator.generate( + "Python (requests)", + method="GET", + url="https://api.example.com", + auth=auth, + ) + assert "Authorization" in result + assert "Basic" in result + + def test_curl_with_apikey_header(self) -> None: + """CURL snippet includes custom API key header.""" + auth = { + "type": "apikey", + "apikey": [ + {"key": "key", "value": "X-API-Key"}, + {"key": "value", "value": "secret"}, + {"key": "in", "value": "header"}, + ], + } + result = SnippetGenerator.generate( + "cURL", + method="GET", + url="https://api.example.com", + auth=auth, + ) + assert "X-API-Key: secret" in result + + def test_curl_with_apikey_query(self) -> None: + """CURL snippet appends API key to URL for query param auth.""" + auth = { + "type": "apikey", + "apikey": [ + {"key": "key", "value": "api_key"}, + {"key": "value", "value": "secret"}, + {"key": "in", "value": "query"}, + ], + } + result = SnippetGenerator.generate( + "cURL", + method="GET", + url="https://api.example.com", + auth=auth, + ) + assert "api_key=secret" in result + + +class TestSnippetOptions: + """Verify that SnippetOptions affect generated output.""" + + def test_indent_count_affects_output(self) -> None: + """Custom indent count changes indentation in output.""" + opts: SnippetOptions = {"indent_count": 4, "indent_type": "space"} + result = SnippetGenerator.generate( + "JavaScript (fetch)", + method="GET", + url="https://example.com", + options=opts, + ) + assert " method" in result + + def test_tab_indent(self) -> None: + """Tab indent type uses tab characters.""" + opts: SnippetOptions = {"indent_count": 1, "indent_type": "tab"} + result = SnippetGenerator.generate( + "JavaScript (fetch)", + method="GET", + url="https://example.com", + options=opts, + ) + assert "\tmethod" in result + + def test_trim_body(self) -> None: + """Trim body strips whitespace from request body.""" + opts: SnippetOptions = {"trim_body": True} + result = SnippetGenerator.generate( + "cURL", + method="POST", + url="https://example.com", + body=" hello ", + options=opts, + ) + assert "hello" in result + assert " hello " not in result + + def test_follow_redirect_curl(self) -> None: + """Follow redirect adds --location flag to cURL (long form default).""" + result_on = SnippetGenerator.generate( + "cURL", + method="GET", + url="https://example.com", + options={"follow_redirect": True}, + ) + assert "--location" in result_on + + result_off = SnippetGenerator.generate( + "cURL", + method="GET", + url="https://example.com", + options={"follow_redirect": False}, + ) + assert "--location" not in result_off + assert "-L" not in result_off + + def test_request_timeout_python(self) -> None: + """Request timeout adds timeout parameter to Python requests.""" + opts: SnippetOptions = {"request_timeout": 30} + result = SnippetGenerator.generate( + "Python (requests)", + method="GET", + url="https://example.com", + options=opts, + ) + assert "timeout=30" in result + + def test_follow_redirect_python_requests(self) -> None: + """Follow redirect=False adds allow_redirects=False.""" + result = SnippetGenerator.generate( + "Python (requests)", + method="GET", + url="https://example.com", + options={"follow_redirect": False}, + ) + assert "allow_redirects=False" in result + + def test_follow_redirect_default_python_requests(self) -> None: + """Default follow_redirect does not add allow_redirects.""" + result = SnippetGenerator.generate( + "Python (requests)", + method="GET", + url="https://example.com", + ) + assert "allow_redirects" not in result + + def test_follow_redirect_javascript_fetch(self) -> None: + """Disabled redirects add redirect: manual to fetch options.""" + result = SnippetGenerator.generate( + "JavaScript (fetch)", + method="GET", + url="https://example.com", + options={"follow_redirect": False}, + ) + assert '"manual"' in result + + def test_get_language_info(self) -> None: + """get_language_info returns correct metadata.""" + info = SnippetGenerator.get_language_info("cURL") + assert info is not None + assert info.lexer == "bash" + assert info.display_name == "cURL" + + def test_get_language_info_unknown(self) -> None: + """get_language_info returns None for unknown language.""" + assert SnippetGenerator.get_language_info("COBOL") is None + + def test_new_option_defaults(self) -> None: + """New options have correct defaults in resolve_options.""" + from services.http.snippet_generator.generator import resolve_options + + opts = resolve_options(None) + assert opts["include_boilerplate"] is True + assert opts["async_await"] is False + assert opts["es6_features"] is False + + def test_per_language_options_curl(self) -> None: + """Verify cURL has follow_redirect and request_timeout in options.""" + info = SnippetGenerator.get_language_info("cURL") + assert info is not None + assert "follow_redirect" in info.applicable_options + assert "request_timeout" in info.applicable_options + + def test_per_language_options_httpie(self) -> None: + """HTTPie has no indent options but has timeout and redirect.""" + info = SnippetGenerator.get_language_info("Shell (HTTPie)") + assert info is not None + assert "indent_count" not in info.applicable_options + assert "request_timeout" in info.applicable_options + assert "follow_redirect" in info.applicable_options + + def test_per_language_options_powershell(self) -> None: + """PowerShell has no indent options but has timeout and redirect.""" + info = SnippetGenerator.get_language_info("PowerShell (RestMethod)") + assert info is not None + assert "indent_count" not in info.applicable_options + assert "request_timeout" in info.applicable_options + assert "follow_redirect" in info.applicable_options + + def test_new_curl_option_defaults(self) -> None: + """New cURL-specific options have correct defaults.""" + from services.http.snippet_generator.generator import resolve_options + + opts = resolve_options(None) + assert opts["multiline"] is True + assert opts["long_form"] is True + assert opts["line_continuation"] == "\\\\" + assert opts["quote_type"] == "single" + assert opts["follow_original_method"] is False + assert opts["silent_mode"] is False + + def test_async_await_javascript_fetch(self) -> None: + """Async/await option wraps fetch in async function with await.""" + result = SnippetGenerator.generate( + "JavaScript (fetch)", + method="GET", + url="https://example.com", + options={"async_await": True}, + ) + assert "async " in result + assert "await fetch" in result + + def test_async_await_default_fetch(self) -> None: + """Default fetch uses .then() chain, not async/await.""" + result = SnippetGenerator.generate( + "JavaScript (fetch)", + method="GET", + url="https://example.com", + ) + assert ".then(" in result + assert "async " not in result + + def test_xhr_no_follow_redirect_option(self) -> None: + """XHR does not include follow_redirect in applicable options.""" + info = SnippetGenerator.get_language_info("JavaScript (XHR)") + assert info is not None + assert "follow_redirect" not in info.applicable_options + + def test_fetch_has_async_await_option(self) -> None: + """JavaScript Fetch includes async_await in applicable options.""" + info = SnippetGenerator.get_language_info("JavaScript (fetch)") + assert info is not None + assert "async_await" in info.applicable_options + + def test_dart_has_include_boilerplate_option(self) -> None: + """Dart http includes include_boilerplate in applicable options.""" + info = SnippetGenerator.get_language_info("Dart (http)") + assert info is not None + assert "include_boilerplate" in info.applicable_options + + def test_java_has_include_boilerplate_option(self) -> None: + """Java OkHttp includes include_boilerplate in applicable options.""" + info = SnippetGenerator.get_language_info("Java (OkHttp)") + assert info is not None + assert "include_boilerplate" in info.applicable_options + + def test_kotlin_has_include_boilerplate_option(self) -> None: + """Kotlin OkHttp includes include_boilerplate in applicable options.""" + info = SnippetGenerator.get_language_info("Kotlin (OkHttp)") + assert info is not None + assert "include_boilerplate" in info.applicable_options + + def test_csharp_restsharp_has_include_boilerplate(self) -> None: + """C# RestSharp includes include_boilerplate in applicable options.""" + info = SnippetGenerator.get_language_info("C# (RestSharp)") + assert info is not None + assert "include_boilerplate" in info.applicable_options + + def test_per_language_options_http_raw(self) -> None: + """HTTP raw only has trim_body.""" + info = SnippetGenerator.get_language_info("HTTP") + assert info is not None + assert info.applicable_options == ("trim_body",) + + def test_per_language_options_axios_has_async(self) -> None: + """NodeJS Axios has async_await in its options.""" + info = SnippetGenerator.get_language_info("NodeJS (Axios)") + assert info is not None + assert "async_await" in info.applicable_options + + def test_per_language_options_nodejs_native_has_es6(self) -> None: + """NodeJS Native has es6_features in its options.""" + info = SnippetGenerator.get_language_info("NodeJS (Native)") + assert info is not None + assert "es6_features" in info.applicable_options + + def test_per_language_options_c_has_boilerplate(self) -> None: + """C (libcurl) has include_boilerplate in its options.""" + info = SnippetGenerator.get_language_info("C (libcurl)") + assert info is not None + assert "include_boilerplate" in info.applicable_options + + def test_per_language_options_swift_has_boilerplate(self) -> None: + """Swift (URLSession) has include_boilerplate but not follow_redirect.""" + info = SnippetGenerator.get_language_info("Swift (URLSession)") + assert info is not None + assert "include_boilerplate" in info.applicable_options + assert "follow_redirect" not in info.applicable_options + + def test_per_language_options_python_http_client_no_redirect(self) -> None: + """Python http.client does not have follow_redirect.""" + info = SnippetGenerator.get_language_info("Python (http.client)") + assert info is not None + assert "follow_redirect" not in info.applicable_options diff --git a/tests/unit/services/http/test_snippet_shell.py b/tests/unit/services/http/test_snippet_shell.py new file mode 100644 index 0000000..46620bf --- /dev/null +++ b/tests/unit/services/http/test_snippet_shell.py @@ -0,0 +1,385 @@ +"""Tests for shell and CLI snippet generators.""" + +from __future__ import annotations + +from services.http.snippet_generator import SnippetGenerator + + +class TestCurlOptions: + """Verify new cURL-specific options.""" + + def test_short_form_flags(self) -> None: + """Long form disabled uses -X, -H, -d, -L short flags.""" + result = SnippetGenerator.generate( + "cURL", + method="POST", + url="https://example.com", + headers="Accept: application/json", + body='{"k": "v"}', + options={"long_form": False, "follow_redirect": True}, + ) + assert "-X POST" in result + assert "-H " in result + assert "-d " in result + assert "-L" in result + assert "--request" not in result + assert "--header" not in result + + def test_long_form_flags(self) -> None: + """Default long form uses --request, --header, --data, --location.""" + result = SnippetGenerator.generate( + "cURL", + method="POST", + url="https://example.com", + headers="Accept: application/json", + body='{"k": "v"}', + options={"long_form": True, "follow_redirect": True}, + ) + assert "--request POST" in result + assert "--header " in result + assert "--data " in result + assert "--location" in result + + def test_multiline_false_single_line(self) -> None: + """Multiline disabled produces a single-line command.""" + result = SnippetGenerator.generate( + "cURL", + method="GET", + url="https://example.com", + headers="Accept: text/html", + options={"multiline": False}, + ) + assert "\n" not in result + assert "curl " in result + + def test_multiline_true_uses_continuation(self) -> None: + """Multiline default splits across lines with continuation char.""" + result = SnippetGenerator.generate( + "cURL", + method="GET", + url="https://example.com", + headers="Accept: text/html", + options={"multiline": True}, + ) + assert "\\\n" in result + + def test_line_continuation_caret(self) -> None: + """Caret continuation character for Windows CMD.""" + result = SnippetGenerator.generate( + "cURL", + method="GET", + url="https://example.com", + headers="Accept: text/html", + options={"multiline": True, "line_continuation": "^"}, + ) + assert "^\n" in result + assert "\\\n" not in result + + def test_line_continuation_backtick(self) -> None: + """Backtick continuation character for PowerShell.""" + result = SnippetGenerator.generate( + "cURL", + method="GET", + url="https://example.com", + headers="Accept: text/html", + options={"multiline": True, "line_continuation": "`"}, + ) + assert "`\n" in result + + def test_quote_type_double(self) -> None: + """Double quote type wraps URL in double quotes.""" + result = SnippetGenerator.generate( + "cURL", + method="GET", + url="https://example.com", + options={"quote_type": "double"}, + ) + assert '"https://example.com"' in result + + def test_quote_type_single(self) -> None: + """Single quote type wraps URL in single quotes.""" + result = SnippetGenerator.generate( + "cURL", + method="GET", + url="https://example.com", + options={"quote_type": "single"}, + ) + assert "'https://example.com'" in result + + def test_follow_original_method(self) -> None: + """Follow original method adds --post301/302/303 flags.""" + result = SnippetGenerator.generate( + "cURL", + method="POST", + url="https://example.com", + options={"follow_original_method": True}, + ) + assert "--post301" in result + assert "--post302" in result + assert "--post303" in result + + def test_follow_original_method_off(self) -> None: + """Follow original method disabled omits --post30x flags.""" + result = SnippetGenerator.generate( + "cURL", + method="POST", + url="https://example.com", + options={"follow_original_method": False}, + ) + assert "--post301" not in result + + def test_silent_mode_long_form(self) -> None: + """Silent mode adds --silent when long form enabled.""" + result = SnippetGenerator.generate( + "cURL", + method="GET", + url="https://example.com", + options={"silent_mode": True, "long_form": True}, + ) + assert "--silent" in result + + def test_silent_mode_short_form(self) -> None: + """Silent mode adds -s when long form disabled.""" + result = SnippetGenerator.generate( + "cURL", + method="GET", + url="https://example.com", + options={"silent_mode": True, "long_form": False}, + ) + assert "-s" in result + assert "--silent" not in result + + def test_silent_mode_off(self) -> None: + """Silent mode disabled omits silent flag.""" + result = SnippetGenerator.generate( + "cURL", + method="GET", + url="https://example.com", + options={"silent_mode": False}, + ) + assert "--silent" not in result + assert "-s " not in result + + def test_curl_applicable_options(self) -> None: + """Verify cURL has the 6 new options plus trim_body, timeout, redirect.""" + info = SnippetGenerator.get_language_info("cURL") + assert info is not None + for opt in ( + "multiline", + "long_form", + "line_continuation", + "quote_type", + "follow_original_method", + "silent_mode", + "trim_body", + "request_timeout", + "follow_redirect", + ): + assert opt in info.applicable_options + # cURL does NOT have indent options + assert "indent_count" not in info.applicable_options + assert "indent_type" not in info.applicable_options + + +class TestHttpRaw: + """Verify raw HTTP snippet generation.""" + + def test_basic_get(self) -> None: + """Raw HTTP snippet contains method, path, and host.""" + result = SnippetGenerator.generate( + "HTTP", method="GET", url="https://api.example.com/users" + ) + assert "GET /users HTTP/1.1" in result + assert "Host: api.example.com" in result + + def test_with_headers(self) -> None: + """Raw HTTP snippet includes request headers.""" + result = SnippetGenerator.generate( + "HTTP", + method="POST", + url="https://api.example.com/data", + headers="Content-Type: application/json", + ) + assert "Content-Type: application/json" in result + + def test_with_body(self) -> None: + """Raw HTTP snippet includes body after blank line.""" + result = SnippetGenerator.generate( + "HTTP", + method="POST", + url="https://api.example.com/data", + body='{"key": "value"}', + ) + assert '{"key": "value"}' in result + assert "Content-Length:" in result + + def test_with_query_string(self) -> None: + """Raw HTTP snippet preserves query string in path.""" + result = SnippetGenerator.generate( + "HTTP", method="GET", url="https://api.example.com/search?q=test" + ) + assert "GET /search?q=test HTTP/1.1" in result + + +class TestShellWget: + """Verify wget snippet generation.""" + + def test_basic_get(self) -> None: + """Wget snippet includes method and URL.""" + result = SnippetGenerator.generate( + "Shell (wget)", method="GET", url="https://api.example.com" + ) + assert "wget" in result + assert "--method=GET" in result + assert "https://api.example.com" in result + + def test_with_headers(self) -> None: + """Wget snippet includes --header flags.""" + result = SnippetGenerator.generate( + "Shell (wget)", + method="POST", + url="https://api.example.com", + headers="Content-Type: application/json", + ) + assert "--header=" in result + assert "Content-Type: application/json" in result + + def test_with_body(self) -> None: + """Wget snippet includes --body-data flag.""" + result = SnippetGenerator.generate( + "Shell (wget)", + method="POST", + url="https://api.example.com", + body='{"key": "value"}', + ) + assert "--body-data=" in result + + def test_no_redirect(self) -> None: + """Wget snippet includes --max-redirect=0 when redirects disabled.""" + result = SnippetGenerator.generate( + "Shell (wget)", + method="GET", + url="https://example.com", + options={"follow_redirect": False}, + ) + assert "--max-redirect=0" in result + + def test_timeout_option(self) -> None: + """Timeout option adds --timeout flag.""" + result = SnippetGenerator.generate( + "Shell (wget)", + method="GET", + url="https://example.com", + options={"request_timeout": 30}, + ) + assert "--timeout=30" in result + + +class TestShellHttpie: + """Verify HTTPie snippet generation.""" + + def test_basic_get(self) -> None: + """HTTPie snippet includes http, method, and URL.""" + result = SnippetGenerator.generate( + "Shell (HTTPie)", method="GET", url="https://api.example.com" + ) + assert "http" in result + assert "GET" in result + assert "https://api.example.com" in result + + def test_with_headers(self) -> None: + """HTTPie snippet includes header key:value pairs.""" + result = SnippetGenerator.generate( + "Shell (HTTPie)", + method="POST", + url="https://api.example.com", + headers="Accept: application/json", + ) + assert "Accept" in result + + def test_no_redirect(self) -> None: + """HTTPie snippet includes --follow=false when redirects disabled.""" + result = SnippetGenerator.generate( + "Shell (HTTPie)", + method="GET", + url="https://example.com", + options={"follow_redirect": False}, + ) + assert "--follow=false" in result + + def test_timeout_option(self) -> None: + """Timeout option adds --timeout flag.""" + result = SnippetGenerator.generate( + "Shell (HTTPie)", + method="GET", + url="https://example.com", + options={"request_timeout": 15}, + ) + assert "--timeout=15" in result + + +class TestPowershellRestmethod: + """Verify PowerShell Invoke-RestMethod snippet generation.""" + + def test_basic_get(self) -> None: + """PowerShell snippet includes Invoke-RestMethod and URL.""" + result = SnippetGenerator.generate( + "PowerShell (RestMethod)", method="GET", url="https://api.example.com" + ) + assert "Invoke-RestMethod" in result + assert "https://api.example.com" in result + assert "-Method GET" in result + + def test_with_headers(self) -> None: + """PowerShell snippet includes $headers hashtable.""" + result = SnippetGenerator.generate( + "PowerShell (RestMethod)", + method="POST", + url="https://api.example.com", + headers="Content-Type: application/json", + ) + assert "$headers" in result + assert "-Headers $headers" in result + + def test_with_body(self) -> None: + """PowerShell snippet includes $body variable.""" + result = SnippetGenerator.generate( + "PowerShell (RestMethod)", + method="POST", + url="https://api.example.com", + body='{"key": "value"}', + ) + assert "$body" in result + assert "-Body $body" in result + + def test_with_auth(self) -> None: + """PowerShell snippet includes auth header.""" + auth = {"type": "bearer", "bearer": [{"key": "token", "value": "tok"}]} + result = SnippetGenerator.generate( + "PowerShell (RestMethod)", + method="GET", + url="https://api.example.com", + auth=auth, + ) + assert "Authorization" in result + assert "Bearer tok" in result + + def test_timeout_option(self) -> None: + """Timeout option adds -TimeoutSec parameter.""" + result = SnippetGenerator.generate( + "PowerShell (RestMethod)", + method="GET", + url="https://example.com", + options={"request_timeout": 20}, + ) + assert "-TimeoutSec 20" in result + + def test_no_redirect(self) -> None: + """Redirects disabled adds -MaximumRedirection 0.""" + result = SnippetGenerator.generate( + "PowerShell (RestMethod)", + method="GET", + url="https://example.com", + options={"follow_redirect": False}, + ) + assert "-MaximumRedirection 0" in result