|
8 | 8 | "fmt" |
9 | 9 | "io" |
10 | 10 | "strconv" |
| 11 | + "sync/atomic" |
11 | 12 | "testing" |
12 | 13 |
|
13 | 14 | iec "github.com/nspcc-dev/neofs-node/internal/ec" |
@@ -1136,3 +1137,184 @@ func parameterizeXHeaders(t testing.TB, p *Prm, ss []string) { |
1136 | 1137 |
|
1137 | 1138 | p.SetCommonParameters(cp) |
1138 | 1139 | } |
| 1140 | + |
| 1141 | +type failingReader struct { |
| 1142 | + data []byte |
| 1143 | + pos int |
| 1144 | + failAfter int |
| 1145 | + err error |
| 1146 | +} |
| 1147 | + |
| 1148 | +func (r *failingReader) Read(p []byte) (n int, err error) { |
| 1149 | + if r.pos >= len(r.data) { |
| 1150 | + return 0, io.EOF |
| 1151 | + } |
| 1152 | + |
| 1153 | + if r.pos >= r.failAfter && r.failAfter > 0 { |
| 1154 | + return 0, r.err |
| 1155 | + } |
| 1156 | + |
| 1157 | + n = copy(p, r.data[r.pos:]) |
| 1158 | + r.pos += n |
| 1159 | + return n, nil |
| 1160 | +} |
| 1161 | + |
| 1162 | +func (r *failingReader) Close() error { |
| 1163 | + return nil |
| 1164 | +} |
| 1165 | + |
| 1166 | +type trackingWriter struct { |
| 1167 | + writeHeaderCount atomic.Int32 |
| 1168 | + writeChunkCount atomic.Int32 |
| 1169 | + failAfterChunks int32 |
| 1170 | + err error |
| 1171 | +} |
| 1172 | + |
| 1173 | +func (w *trackingWriter) WriteHeader(*object.Object) error { |
| 1174 | + w.writeHeaderCount.Add(1) |
| 1175 | + return nil |
| 1176 | +} |
| 1177 | + |
| 1178 | +func (w *trackingWriter) WriteChunk([]byte) error { |
| 1179 | + count := w.writeChunkCount.Add(1) |
| 1180 | + |
| 1181 | + if w.failAfterChunks > 0 && count == w.failAfterChunks { |
| 1182 | + return w.err |
| 1183 | + } |
| 1184 | + return nil |
| 1185 | +} |
| 1186 | + |
| 1187 | +type testStorageWithFailingReader struct { |
| 1188 | + unimplementedLocalStorage |
| 1189 | + obj *object.Object |
| 1190 | + failAfter int |
| 1191 | + err error |
| 1192 | +} |
| 1193 | + |
| 1194 | +func (s *testStorageWithFailingReader) get(*execCtx) (*object.Object, io.ReadCloser, error) { |
| 1195 | + if s.obj == nil { |
| 1196 | + return nil, nil, errors.New("object not found") |
| 1197 | + } |
| 1198 | + |
| 1199 | + payload := s.obj.Payload() |
| 1200 | + reader := &failingReader{ |
| 1201 | + data: payload, |
| 1202 | + failAfter: s.failAfter, |
| 1203 | + err: s.err, |
| 1204 | + } |
| 1205 | + |
| 1206 | + objWithoutPayload := s.obj.CutPayload() |
| 1207 | + objWithoutPayload.SetPayloadSize(s.obj.PayloadSize()) |
| 1208 | + return objWithoutPayload, reader, nil |
| 1209 | +} |
| 1210 | + |
| 1211 | +func (s *testStorageWithFailingReader) Head(oid.Address, bool) (*object.Object, error) { |
| 1212 | + if s.obj == nil { |
| 1213 | + return nil, errors.New("object not found") |
| 1214 | + } |
| 1215 | + return s.obj.CutPayload(), nil |
| 1216 | +} |
| 1217 | + |
| 1218 | +func TestDoubleWriteHeaderOnPayloadReadFailure(t *testing.T) { |
| 1219 | + ctx := context.Background() |
| 1220 | + addr := oidtest.Address() |
| 1221 | + |
| 1222 | + payloadSize := 1024 * 1024 // 1MB > chunk (256KB) |
| 1223 | + payload := make([]byte, payloadSize) |
| 1224 | + _, _ = rand.Read(payload) |
| 1225 | + |
| 1226 | + obj := generateObject(addr, nil, payload) |
| 1227 | + |
| 1228 | + readErr := errors.New("simulated payload read error") |
| 1229 | + storage := &testStorageWithFailingReader{ |
| 1230 | + obj: obj, |
| 1231 | + failAfter: 300 * 1024, // > chunk |
| 1232 | + err: readErr, |
| 1233 | + } |
| 1234 | + |
| 1235 | + anyNodeLists, nodeStrs := testNodeMatrix(t, []int{1}) |
| 1236 | + |
| 1237 | + clientCache := &testClientCache{ |
| 1238 | + clients: make(map[string]*testClient), |
| 1239 | + } |
| 1240 | + remoteClient := newTestClient() |
| 1241 | + remoteClient.addResult(addr, obj, nil) |
| 1242 | + clientCache.clients[nodeStrs[0][0]] = remoteClient |
| 1243 | + |
| 1244 | + svc := &Service{cfg: new(cfg)} |
| 1245 | + svc.log = zaptest.NewLogger(t) |
| 1246 | + svc.localObjects = storage |
| 1247 | + svc.localStorage = storage |
| 1248 | + svc.clientCache = clientCache |
| 1249 | + svc.neoFSNet = &testNeoFS{ |
| 1250 | + vectors: map[oid.Address][][]netmap.NodeInfo{ |
| 1251 | + addr: anyNodeLists, |
| 1252 | + }, |
| 1253 | + } |
| 1254 | + |
| 1255 | + writer := &trackingWriter{} |
| 1256 | + |
| 1257 | + var prm Prm |
| 1258 | + prm.SetObjectWriter(writer) |
| 1259 | + prm.WithAddress(addr) |
| 1260 | + prm.common = new(util.CommonPrm) |
| 1261 | + |
| 1262 | + err := svc.Get(ctx, prm) |
| 1263 | + require.ErrorIs(t, err, readErr) |
| 1264 | + |
| 1265 | + t.Logf("WriteHeader called: %d times", writer.writeHeaderCount.Load()) |
| 1266 | + t.Logf("WriteChunk called: %d times", writer.writeChunkCount.Load()) |
| 1267 | + require.EqualValues(t, 1, writer.writeHeaderCount.Load()) |
| 1268 | +} |
| 1269 | + |
| 1270 | +func TestDoubleWriteHeaderOnChunkWriteFailure(t *testing.T) { |
| 1271 | + ctx := context.Background() |
| 1272 | + addr := oidtest.Address() |
| 1273 | + |
| 1274 | + payloadSize := 1024 * 1024 // 1MB > chunk (256KB) |
| 1275 | + payload := make([]byte, payloadSize) |
| 1276 | + _, _ = rand.Read(payload) |
| 1277 | + |
| 1278 | + obj := generateObject(addr, nil, payload) |
| 1279 | + |
| 1280 | + storage := newTestStorage() |
| 1281 | + storage.addPhy(addr, obj) |
| 1282 | + |
| 1283 | + anyNodeLists, nodeStrs := testNodeMatrix(t, []int{1}) |
| 1284 | + |
| 1285 | + clientCache := &testClientCache{ |
| 1286 | + clients: make(map[string]*testClient), |
| 1287 | + } |
| 1288 | + remoteClient := newTestClient() |
| 1289 | + remoteClient.addResult(addr, obj, nil) |
| 1290 | + clientCache.clients[nodeStrs[0][0]] = remoteClient |
| 1291 | + |
| 1292 | + svc := &Service{cfg: new(cfg)} |
| 1293 | + svc.log = zaptest.NewLogger(t) |
| 1294 | + svc.localObjects = storage |
| 1295 | + svc.localStorage = storage |
| 1296 | + svc.clientCache = clientCache |
| 1297 | + svc.neoFSNet = &testNeoFS{ |
| 1298 | + vectors: map[oid.Address][][]netmap.NodeInfo{ |
| 1299 | + addr: anyNodeLists, |
| 1300 | + }, |
| 1301 | + } |
| 1302 | + |
| 1303 | + writeErr := errors.New("simulated chunk write error") |
| 1304 | + writer := &trackingWriter{ |
| 1305 | + failAfterChunks: 2, |
| 1306 | + err: writeErr, |
| 1307 | + } |
| 1308 | + |
| 1309 | + var prm Prm |
| 1310 | + prm.SetObjectWriter(writer) |
| 1311 | + prm.WithAddress(addr) |
| 1312 | + prm.common = new(util.CommonPrm) |
| 1313 | + |
| 1314 | + err := svc.Get(ctx, prm) |
| 1315 | + require.ErrorIs(t, err, writeErr) |
| 1316 | + |
| 1317 | + t.Logf("WriteHeader called: %d times", writer.writeHeaderCount.Load()) |
| 1318 | + t.Logf("WriteChunk called: %d times", writer.writeChunkCount.Load()) |
| 1319 | + require.EqualValues(t, 1, writer.writeHeaderCount.Load()) |
| 1320 | +} |
0 commit comments