Skip to content

Commit

Permalink
Fix: ml/engine/utils/FileUtils casts long file length to int incorrec…
Browse files Browse the repository at this point in the history
…tly (opensearch-project#3198)

* Use longs when splitting model zip file

Signed-off-by: Max Lepikhin <[email protected]>

* add test

Signed-off-by: Max Lepikhin <[email protected]>

* spotless

Signed-off-by: Max Lepikhin <[email protected]>

* clean up test

Signed-off-by: Max Lepikhin <[email protected]>

---------

Signed-off-by: Max Lepikhin <[email protected]>
Signed-off-by: tkykenmt <[email protected]>
  • Loading branch information
maxlepikhin authored and tkykenmt committed Dec 15, 2024
1 parent cea0cc3 commit f8a479a
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,16 @@ public class FileUtils {
* @throws IOException
*/
public static List<String> splitFileIntoChunks(File file, Path outputPath, int chunkSize) throws IOException {
int fileSize = (int) file.length();
long fileSize = file.length();
ArrayList<String> nameList = new ArrayList<>();
try (InputStream inStream = new BufferedInputStream(new FileInputStream(file))) {
int numberOfChunk = 0;
int totalBytesRead = 0;
long totalBytesRead = 0;
while (totalBytesRead < fileSize) {
String partName = numberOfChunk + "";
int bytesRemaining = fileSize - totalBytesRead;
long bytesRemaining = fileSize - totalBytesRead;
if (bytesRemaining < chunkSize) {
chunkSize = bytesRemaining;
chunkSize = (int) bytesRemaining;
}
byte[] temporary = new byte[chunkSize];
int bytesRead = inStream.read(temporary, 0, chunkSize);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.engine.utils;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import java.io.File;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

public class FileUtilsTest {
private TemporaryFolder tempDir;

@Before
public void setUp() throws Exception {
tempDir = new TemporaryFolder();
tempDir.create();
}

@After
public void tearDown() {
if (tempDir != null) {
tempDir.delete();
}
}

@Test
public void testSplitFileIntoChunks() throws Exception {
// Write file.
Random random = new Random();
File file = tempDir.newFile("large_file");
byte[] data = new byte[1017];
random.nextBytes(data);
Files.write(file.toPath(), data);

// Split file into chunks.
int chunkSize = 325;
List<String> chunkPaths = FileUtils.splitFileIntoChunks(file, tempDir.newFolder().toPath(), chunkSize);

// Verify.
int currentPosition = 0;
for (String chunkPath : chunkPaths) {
byte[] chunk = Files.readAllBytes(Path.of(chunkPath));
assertTrue("Chunk size", currentPosition + chunk.length <= data.length);
Assert.assertArrayEquals(Arrays.copyOfRange(data, currentPosition, currentPosition + chunk.length), chunk);
currentPosition += chunk.length;
}
assertEquals(currentPosition, data.length);
}
}

0 comments on commit f8a479a

Please sign in to comment.