Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/core/tasks/scheduled/impl/huggingface/operator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from itertools import count

Check warning on line 1 in src/core/tasks/scheduled/impl/huggingface/operator.py

View workflow job for this annotation

GitHub Actions / flake8

[flake8] src/core/tasks/scheduled/impl/huggingface/operator.py#L1 <100>

Missing docstring in public module
Raw output
./src/core/tasks/scheduled/impl/huggingface/operator.py:1:1: D100 Missing docstring in public module

from src.core.tasks.scheduled.templates.operator import ScheduledTaskOperatorBase
from src.db.client.async_ import AsyncDatabaseClient
Expand Down Expand Up @@ -30,7 +31,10 @@

# Otherwise, push to huggingface
run_dt = await self.adb_client.get_current_database_time()
outputs = await self.adb_client.get_data_sources_raw_for_huggingface()
self.hf_client.push_data_sources_raw_to_hub(outputs)
for idx in count(start=1):
outputs = await self.adb_client.get_data_sources_raw_for_huggingface(page=idx)
if len(outputs) == 0:
break
self.hf_client.push_data_sources_raw_to_hub(outputs, idx=idx)

await self.adb_client.set_hugging_face_upload_state(run_dt.replace(tzinfo=None))
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from src.core.tasks.scheduled.impl.huggingface.queries.get.convert import convert_url_status_to_relevant, \
convert_fine_to_coarse_record_type
from src.core.tasks.scheduled.impl.huggingface.queries.get.model import GetForLoadingToHuggingFaceOutput
from src.db.client.helpers import add_standard_limit_and_offset
from src.db.models.impl.url.html.compressed.sqlalchemy import URLCompressedHTML
from src.db.models.impl.url.core.sqlalchemy import URL
from src.db.queries.base.builder import QueryBuilderBase
Expand All @@ -13,6 +14,10 @@

class GetForLoadingToHuggingFaceQueryBuilder(QueryBuilderBase):

def __init__(self, page: int):

Check warning on line 17 in src/core/tasks/scheduled/impl/huggingface/queries/get/core.py

View workflow job for this annotation

GitHub Actions / flake8

[flake8] src/core/tasks/scheduled/impl/huggingface/queries/get/core.py#L17 <107>

Missing docstring in __init__
Raw output
./src/core/tasks/scheduled/impl/huggingface/queries/get/core.py:17:1: D107 Missing docstring in __init__
super().__init__()
self.page = page


async def run(self, session: AsyncSession) -> list[GetForLoadingToHuggingFaceOutput]:
label_url_id = 'url_id'
Expand Down Expand Up @@ -42,6 +47,7 @@
])
)
)
query = add_standard_limit_and_offset(page=self.page, statement=query)
db_results = await sh.mappings(
session=session,
query=query
Expand Down
4 changes: 2 additions & 2 deletions src/db/client/async_.py
Original file line number Diff line number Diff line change
Expand Up @@ -1463,9 +1463,9 @@
)
session.add(compressed_html)

async def get_data_sources_raw_for_huggingface(self) -> list[GetForLoadingToHuggingFaceOutput]:
async def get_data_sources_raw_for_huggingface(self, page: int) -> list[GetForLoadingToHuggingFaceOutput]:

Check warning on line 1466 in src/db/client/async_.py

View workflow job for this annotation

GitHub Actions / flake8

[flake8] src/db/client/async_.py#L1466 <102>

Missing docstring in public method
Raw output
./src/db/client/async_.py:1466:1: D102 Missing docstring in public method
return await self.run_query_builder(
GetForLoadingToHuggingFaceQueryBuilder()
GetForLoadingToHuggingFaceQueryBuilder(page)
)

async def set_hugging_face_upload_state(self, dt: datetime) -> None:
Expand Down
26 changes: 22 additions & 4 deletions src/external/huggingface/hub/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

from datasets import Dataset
from huggingface_hub import HfApi

from src.external.huggingface.hub.constants import DATA_SOURCES_RAW_REPO_ID
from src.external.huggingface.hub.format import format_as_huggingface_dataset
Expand All @@ -10,22 +11,39 @@

def __init__(self, token: str):
self.token = token
self.api = HfApi(token=token)

def _push_dataset_to_hub(self, repo_id: str, dataset: Dataset) -> None:
def _push_dataset_to_hub(
self,
repo_id: str,
dataset: Dataset,
idx: int
) -> None:
"""
Modifies:
- repository on Hugging Face, identified by `repo_id`
"""
dataset.push_to_hub(repo_id=repo_id, token=self.token)
dataset.to_parquet(f"part_{idx}.parquet")
self.api.upload_file(
path_or_fileobj=f"part_{idx}.parquet",
path_in_repo=f"data/part_{idx}.parquet",
repo_id=repo_id,
repo_type="dataset",
)

def push_data_sources_raw_to_hub(
self,
outputs: list[GetForLoadingToHuggingFaceOutput]
outputs: list[GetForLoadingToHuggingFaceOutput],
idx: int
) -> None:
"""
Modifies:
- repository on Hugging Face, identified by `DATA_SOURCES_RAW_REPO_ID`
"""
dataset = format_as_huggingface_dataset(outputs)
print(dataset)
self._push_dataset_to_hub(repo_id=DATA_SOURCES_RAW_REPO_ID, dataset=dataset)
self._push_dataset_to_hub(
repo_id=DATA_SOURCES_RAW_REPO_ID,
dataset=dataset,
idx=idx
)

Check warning on line 49 in src/external/huggingface/hub/client.py

View workflow job for this annotation

GitHub Actions / flake8

[flake8] src/external/huggingface/hub/client.py#L49 <292>

no newline at end of file
Raw output
./src/external/huggingface/hub/client.py:49:10: W292 no newline at end of file