Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import android.Manifest
import android.graphics.PixelFormat
import android.os.Bundle
import android.util.Log
import android.view.Choreographer
import android.view.ContextThemeWrapper
import android.view.View
import android.view.ViewGroup
import android.view.WindowManager
import androidx.test.platform.app.InstrumentationRegistry
import androidx.test.rule.GrantPermissionRule
Expand All @@ -16,7 +18,8 @@ import com.facebook.testing.screenshot.ViewHelpers
import org.junit.Assert.assertTrue
import org.junit.Rule
import org.junit.Test
import java.io.File
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit

/**
* Base screenshot test that automatically discovers and tests all Storybook stories.
Expand All @@ -28,17 +31,23 @@ import java.io.File
* class StoryScreenshotTest : BaseStoryScreenshotTest()
* ```
*
* This test automatically bootstraps the story manifest if it doesn't exist,
* then creates a screenshot for each story. No manual test methods needed -
* just add stories to Storybook and they get tested automatically.
* Experiment: js-driven-story-loop
* ---------------------------------
* JS drives the entire story sequence. StoryRenderer renders all stories one at a
* time in a for loop, calling notifyStoryReady(storyId) after each render and
* waiting for the Promise to resolve before moving to the next story.
*
* The test thread loops reactively: wait for notifyStoryReady → take screenshot →
* resolve the Promise → repeat until allStoriesDone().
*
* No events, no isBlockingSynchronousMethod, no manifest pre-loading step.
*/
abstract class BaseStoryScreenshotTest {

companion object {
private const val TAG = "BaseStoryScreenshotTest"
private const val DEFAULT_LOAD_TIMEOUT_MS = 5000L
private const val DEFAULT_BOOTSTRAP_TIMEOUT_MS = 10000L
private const val BOOTSTRAP_STORY_NAME = "__bootstrap__"
private const val DEFAULT_TOTAL_TIMEOUT_MS = 300_000L // 5 minutes for all stories

private const val SCREEN_WIDTH_PX = 1080
private const val SCREEN_HEIGHT_PX = 1920
Expand All @@ -58,118 +67,84 @@ abstract class BaseStoryScreenshotTest {
open fun getMainComponentName(): String = "StoryRenderer"

/**
* Override to customize the React Native load timeout per story.
* Override to customize the per-story screenshot timeout.
* Default is 5000ms.
*/
open fun getLoadTimeoutMs(): Long = DEFAULT_LOAD_TIMEOUT_MS

/**
* Override to customize the timeout for manifest bootstrap.
* Default is 10000ms.
*/
open fun getBootstrapTimeoutMs(): Long = DEFAULT_BOOTSTRAP_TIMEOUT_MS

/**
* Override to filter which stories should be screenshotted.
* Return true to include the story, false to skip it.
* Default includes all stories.
*/
open fun shouldScreenshotStory(storyInfo: StoryInfo): Boolean = true

/**
* Screenshots all stories found in the manifest.
* Each story gets its own screenshot named after its ID.
* If the manifest doesn't exist, it will be bootstrapped automatically.
* Screenshots all stories. JS tells us which story to screenshot and when —
* the test thread just reacts to notifyStoryReady() calls.
*/
@Test
fun screenshotAllStories() {
val context = InstrumentationRegistry.getInstrumentation().targetContext
val externalDir = context.getExternalFilesDir("screenshots")
val manifestFile = File(externalDir, StorybookRegistry.STORIES_FILE_NAME)

if (!manifestFile.exists()) {
Log.d(TAG, "Manifest not found, bootstrapping...")
bootstrapManifest(manifestFile)
}

val allStories = StorybookRegistry.getStoriesFromFile(externalDir!!)
val stories = allStories.filter { shouldScreenshotStory(it) }
val instrumentation = InstrumentationRegistry.getInstrumentation()

Log.d(TAG, "Found ${allStories.size} stories, ${stories.size} after filtering")
assertTrue("No stories found in manifest", stories.isNotEmpty())
StorybookRegistry.prepareForRun()

var successCount = 0
var failureCount = 0
val failures = mutableListOf<String>()

for (story in stories) {
try {
screenshotStory(story)
successCount++
} catch (e: Exception) {
failureCount++
val errorMsg = "${story.title}/${story.name}: ${e.message}"
failures.add(errorMsg)
Log.e(TAG, "Failed to screenshot story: $errorMsg", e)
mountSurface { view ->
// React to stories as JS renders them, until allStoriesDone() is called.
while (true) {
StorybookRegistry.prepareForNextStory()
val storyId = StorybookRegistry.awaitStoryReady(getLoadTimeoutMs())

if (storyId == null) {
// allStoriesDone() was called — JS has finished.
Log.d(TAG, "All stories done")
break
}

try {
// Two frames so Fabric's native view mutations are fully applied.
waitTwoFrames()

val screenshotName = storyId.replace("--", "_")
instrumentation.runOnMainSync {
// view.draw(canvas) can't capture children that have hardware display
// lists. Force the entire tree to software so draw() sees all content.
setLayerTypeSoftwareRecursively(view)
Screenshot.snap(view).setName(screenshotName).record()
}
Log.d(TAG, "Screenshot captured: $screenshotName")
} catch (e: Exception) {
failures.add("$storyId: ${e.message}")
Log.e(TAG, "Failed to screenshot story: $storyId", e)
} finally {
// Resolve the notifyStoryReady() Promise so JS can render the next story.
StorybookRegistry.resolveCurrentStory()
}
}
}

Log.d(TAG, "Screenshot results: $successCount passed, $failureCount failed")
if (failures.isNotEmpty()) {
Log.e(TAG, "Failed stories:\n${failures.joinToString("\n")}")
}

Log.d(TAG, "${failures.size} stories failed")
assertTrue(
"Some stories failed to screenshot: ${failures.joinToString(", ")}",
failures.isEmpty()
)
}

private fun screenshotStory(storyInfo: StoryInfo) {
val storyName = storyInfo.toStoryName()
Log.d(TAG, "Screenshotting: $storyName (id: ${storyInfo.id})")

StorybookRegistry.prepareForNextStory()
renderStory(storyName) { view ->
StorybookRegistry.awaitStoryReady(getLoadTimeoutMs())
val screenshotName = storyInfo.id.replace("--", "_")
Screenshot.snap(view).setName(screenshotName).record()
Log.d(TAG, "Screenshot captured: $screenshotName")
}
}

private fun bootstrapManifest(manifestFile: File) {
Log.d(TAG, "Launching StoryRenderer to generate manifest...")
renderStory(BOOTSTRAP_STORY_NAME) {
waitForManifestFile(manifestFile)
}
Log.d(TAG, "Bootstrap complete")
}

/**
* Renders the given story name into a view, calls [onRendered] with that view,
* then tears down. Handles both old arch (ReactRootView) and new arch (ReactSurface).
* Mounts a single React surface for the whole test run, calls [onMounted] with
* the view, then tears down. Handles both old arch (ReactRootView) and new arch
* (ReactSurface). No props are passed — JS drives itself.
*/
private fun renderStory(storyName: String, onRendered: (view: View) -> Unit) {
private fun mountSurface(onMounted: (view: View) -> Unit) {
val instrumentation = InstrumentationRegistry.getInstrumentation()
val app = instrumentation.targetContext.applicationContext as ReactApplication
val props = Bundle().apply { putString("storyName", storyName) }

val reactHost = app.reactHost
if (reactHost != null) {
// New arch (Fabric/bridgeless): ReactHost + ReactSurface.
// Fabric won't commit its render tree until the surface's host view is parented
// to a real Window. Test processes don't have an Activity window, so we attach
// via WindowManager using TYPE_APPLICATION_OVERLAY (requires SYSTEM_ALERT_WINDOW).
// Wrap with the app theme so AppCompat widgets (e.g. Switch) resolve styled attrs.
val context = ContextThemeWrapper(
instrumentation.targetContext,
instrumentation.targetContext.applicationInfo.theme
)
val surface = reactHost.createSurface(
context,
getMainComponentName(),
props
Bundle()
)

val view = surface.view
Expand All @@ -186,19 +161,18 @@ abstract class BaseStoryScreenshotTest {
)

instrumentation.runOnMainSync {
// Force software rendering so Screenshot.snap() can capture via draw(canvas).
// WindowManager views are hardware-accelerated by default; GPU content is
// invisible to a software canvas.
view.setLayerType(View.LAYER_TYPE_SOFTWARE, null)
wm.addView(view, params)
surface.start()
}

onRendered(view)

instrumentation.runOnMainSync {
surface.stop()
wm.removeView(view)
try {
onMounted(view)
} finally {
instrumentation.runOnMainSync {
surface.stop()
wm.removeView(view)
}
}
} else {
// Old arch: ReactRootView + ReactInstanceManager (deprecated API).
Expand All @@ -208,34 +182,40 @@ abstract class BaseStoryScreenshotTest {
@Suppress("DEPRECATION")
val reactInstanceManager = app.reactNativeHost.reactInstanceManager

// ReactRootView.startReactApplication() checks isOnUiThread() internally.
instrumentation.runOnMainSync {
rootView.startReactApplication(reactInstanceManager, getMainComponentName(), props)
rootView.startReactApplication(reactInstanceManager, getMainComponentName(), Bundle())
}

// setupView().layout() calls measure()+layout() at the fixed dimensions, which
// triggers onMeasure() → attachToReactInstanceManager() on the ReactRootView.
ViewHelpers.setupView(rootView)
.setExactWidthPx(SCREEN_WIDTH_PX)
.setExactHeightPx(SCREEN_HEIGHT_PX)
.layout()

onRendered(rootView)

instrumentation.runOnMainSync { rootView.unmountReactApplication() }
try {
onMounted(rootView)
} finally {
instrumentation.runOnMainSync { rootView.unmountReactApplication() }
}
}
}

private fun waitForManifestFile(manifestFile: File) {
val deadline = System.currentTimeMillis() + getBootstrapTimeoutMs()
while (!manifestFile.exists() && System.currentTimeMillis() < deadline) {
Thread.sleep(100)
private fun setLayerTypeSoftwareRecursively(view: View) {
view.setLayerType(View.LAYER_TYPE_SOFTWARE, null)
if (view is ViewGroup) {
for (i in 0 until view.childCount) {
setLayerTypeSoftwareRecursively(view.getChildAt(i))
}
}
if (!manifestFile.exists()) {
throw IllegalStateException(
"Manifest file did not appear within ${getBootstrapTimeoutMs()}ms. " +
"Make sure configure(view) is called in your app and the StoryRenderer is registered."
)
}

private fun waitTwoFrames() {
val instrumentation = InstrumentationRegistry.getInstrumentation()
repeat(2) {
val latch = CountDownLatch(1)
instrumentation.runOnMainSync {
Choreographer.getInstance().postFrameCallback { latch.countDown() }
}
latch.await(1000, TimeUnit.MILLISECONDS)
}
}
}
Loading
Loading