|
9 | 9 | def custom_collate(batch): |
10 | 10 | """Custom collate function to handle a complex data structure. |
11 | 11 |
|
12 | | - Each sample is a dictionary containing numpy arrays and another dictionary |
13 | | - with sparse matrices. Since we're using a batch size of 1, this function |
14 | | - simplifies the handling of these structures. |
15 | | -
|
16 | 12 | Args: |
17 | | - batch: A list of samples, where each sample is the complex data structure |
18 | | - described above. |
| 13 | + batch: A list of samples, where each sample is the complex data structure. |
19 | 14 |
|
20 | 15 | Returns: |
21 | 16 | Processed batch ready for model input. |
22 | 17 | """ |
23 | | - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
24 | 18 | # Unpack the single sample from the batch |
25 | 19 | sample = batch[0] |
26 | | - # Initialize a new dictionary to store the processed sample |
27 | 20 | processed_sample = {} |
28 | 21 |
|
29 | 22 | for key, value in sample.items(): |
30 | 23 | if isinstance(value, np.ndarray): |
31 | 24 | # Convert numpy arrays to tensors |
32 | | - processed_sample[key] = torch.tensor(value).to(device) |
| 25 | + processed_sample[key] = torch.tensor(value) |
33 | 26 | elif isinstance(value, dict): |
34 | | - # For the nested dictionary, we assume it contains sparse matrices |
35 | | - # and pass it through directly without modifications |
| 27 | + # For the nested dictionary, directly pass it through without GPU operations |
36 | 28 | processed_sample[key] = { |
37 | | - subkey: subvalue.to(device) for subkey, subvalue in value.items() |
| 29 | + subkey: subvalue for subkey, subvalue in value.items() |
38 | 30 | } |
39 | 31 | else: |
40 | 32 | # Directly pass through any other types of values |
|
0 commit comments