Skip to content
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package com.testapp

import android.Manifest
import android.graphics.PixelFormat
import android.os.Handler
import android.os.Looper
import android.view.Choreographer
import android.view.ContextThemeWrapper
import android.view.View
import android.view.ViewGroup
import android.view.WindowManager
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.platform.app.InstrumentationRegistry
import androidx.test.rule.GrantPermissionRule
import com.facebook.react.ReactApplication
import com.facebook.testing.screenshot.Screenshot
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import java.util.concurrent.CountDownLatch

@RunWith(AndroidJUnit4::class)
class SimpleComponentScreenshotTest {

@get:Rule
val permissionRule: GrantPermissionRule = GrantPermissionRule.grant(
Manifest.permission.WRITE_EXTERNAL_STORAGE,
Manifest.permission.READ_EXTERNAL_STORAGE,
Manifest.permission.SYSTEM_ALERT_WINDOW
)

@Test
fun screenshotSimpleComponent() {
val instrumentation = InstrumentationRegistry.getInstrumentation()
val app = instrumentation.targetContext.applicationContext as ReactApplication
val reactHost = app.reactHost!!

val context = ContextThemeWrapper(
instrumentation.targetContext,
instrumentation.targetContext.applicationInfo.theme
)
val surface = reactHost.createSurface(context, "SimpleComponent", null)
val view = surface.view
?: throw IllegalStateException("ReactSurface returned a null view")

val wm = instrumentation.targetContext
.getSystemService(android.content.Context.WINDOW_SERVICE) as WindowManager
val params = WindowManager.LayoutParams(
1080, 1920,
WindowManager.LayoutParams.TYPE_APPLICATION_OVERLAY,
WindowManager.LayoutParams.FLAG_NOT_FOCUSABLE,
PixelFormat.TRANSLUCENT
)

instrumentation.runOnMainSync {
view.setLayerType(View.LAYER_TYPE_SOFTWARE, null)
wm.addView(view, params)
surface.start()
}

waitTwoFrames()

instrumentation.runOnMainSync {
setLayerTypeSoftwareRecursively(view)
Screenshot.snap(view).setName("simple_component").record()
}

instrumentation.runOnMainSync {
surface.stop()
wm.removeView(view)
Comment on lines +55 to +70
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not deterministic

}
}

private fun waitTwoFrames() {
repeat(2) {
val latch = CountDownLatch(1)
Handler(Looper.getMainLooper()).post {
Choreographer.getInstance().postFrameCallback { latch.countDown() }
}
latch.await()
}
}

private fun setLayerTypeSoftwareRecursively(view: View) {
view.setLayerType(View.LAYER_TYPE_SOFTWARE, null)
if (view is ViewGroup) {
for (i in 0 until view.childCount) {
setLayerTypeSoftwareRecursively(view.getChildAt(i))
}
}
}
}
3 changes: 3 additions & 0 deletions index.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ const { configure, StoryRenderer } = require('rn-storybook-auto-screenshots');
configure(view);
AppRegistry.registerComponent('StoryRenderer', () => StoryRenderer);

const { SimpleComponent } = require('rn-storybook-auto-screenshots');
AppRegistry.registerComponent('SimpleComponent', () => SimpleComponent);

const SimpleTestComponent = () => <View><Text>Hello</Text></View>;
AppRegistry.registerComponent('SimpleTestComponent', () => SimpleTestComponent);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@ package com.rnstorybookautoscreenshots
import android.Manifest
import android.graphics.PixelFormat
import android.os.Bundle
import android.os.Handler
import android.os.Looper
import android.util.Log
import android.view.Choreographer
import android.view.ContextThemeWrapper
import android.view.View
import android.view.ViewGroup
import android.view.WindowManager
import androidx.test.platform.app.InstrumentationRegistry
import androidx.test.rule.GrantPermissionRule
Expand All @@ -16,10 +20,10 @@ import com.facebook.testing.screenshot.ViewHelpers
import org.junit.Assert.assertTrue
import org.junit.Rule
import org.junit.Test
import java.io.File
import java.util.concurrent.CountDownLatch

/**
* Base screenshot test that automatically discovers and tests all Storybook stories.
* Base screenshot test that automatically renders and screenshots all Storybook stories.
*
* Extend this class in your app's androidTest directory:
*
Expand All @@ -28,17 +32,16 @@ import java.io.File
* class StoryScreenshotTest : BaseStoryScreenshotTest()
* ```
*
* This test automatically bootstraps the story manifest if it doesn't exist,
* then creates a screenshot for each story. No manual test methods needed -
* just add stories to Storybook and they get tested automatically.
* A single React surface is mounted for the entire test run. JS drives the story
* loop — rendering each story and calling notifyStoryReady() after React commits.
* The test thread screenshots and then resolves the JS Promise to advance the loop.
* When all stories are done JS calls allStoriesDone() and the test exits.
*/
abstract class BaseStoryScreenshotTest {

companion object {
private const val TAG = "BaseStoryScreenshotTest"
private const val DEFAULT_LOAD_TIMEOUT_MS = 5000L
private const val DEFAULT_BOOTSTRAP_TIMEOUT_MS = 10000L
private const val BOOTSTRAP_STORY_NAME = "__bootstrap__"

private const val SCREEN_WIDTH_PX = 1080
private const val SCREEN_HEIGHT_PX = 1920
Expand All @@ -58,102 +61,85 @@ abstract class BaseStoryScreenshotTest {
open fun getMainComponentName(): String = "StoryRenderer"

/**
* Override to customize the React Native load timeout per story.
* Override to customize the per-story timeout.
* Default is 5000ms.
*/
open fun getLoadTimeoutMs(): Long = DEFAULT_LOAD_TIMEOUT_MS

/**
* Override to customize the timeout for manifest bootstrap.
* Default is 10000ms.
*/
open fun getBootstrapTimeoutMs(): Long = DEFAULT_BOOTSTRAP_TIMEOUT_MS

/**
* Override to filter which stories should be screenshotted.
* Override to skip specific stories.
* Return true to include the story, false to skip it.
* Default includes all stories.
*/
open fun shouldScreenshotStory(storyInfo: StoryInfo): Boolean = true
open fun shouldScreenshotStory(storyId: String): Boolean = true

/**
* Screenshots all stories found in the manifest.
* Each story gets its own screenshot named after its ID.
* If the manifest doesn't exist, it will be bootstrapped automatically.
* Screenshots all Storybook stories.
*
* Mounts a single StoryRenderer surface. JS iterates through all stories,
* calling notifyStoryReady() after each commit. The test thread screenshots
* and resolves the Promise to let JS advance.
*/
@Test
fun screenshotAllStories() {
val context = InstrumentationRegistry.getInstrumentation().targetContext
val externalDir = context.getExternalFilesDir("screenshots")
val manifestFile = File(externalDir, StorybookRegistry.STORIES_FILE_NAME)

if (!manifestFile.exists()) {
Log.d(TAG, "Manifest not found, bootstrapping...")
bootstrapManifest(manifestFile)
mountSurface { view ->
runStoryLoop(view)
}
}

val allStories = StorybookRegistry.getStoriesFromFile(externalDir!!)
val stories = allStories.filter { shouldScreenshotStory(it) }
private fun runStoryLoop(view: View) {
val instrumentation = InstrumentationRegistry.getInstrumentation()
val failures = mutableListOf<String>()
var successCount = 0

Log.d(TAG, "Found ${allStories.size} stories, ${stories.size} after filtering")
assertTrue("No stories found in manifest", stories.isNotEmpty())
while (true) {
StorybookRegistry.prepareForNextStory()
val storyId = StorybookRegistry.awaitStoryReady(getLoadTimeoutMs()) ?: break

var successCount = 0
var failureCount = 0
val failures = mutableListOf<String>()
if (!shouldScreenshotStory(storyId)) {
Log.d(TAG, "Skipping story: $storyId")
StorybookRegistry.resolveCurrentStory()
continue
}

for (story in stories) {
Log.d(TAG, "Screenshotting: $storyId")
try {
screenshotStory(story)
// Wait for Fabric to apply native mutations before snapping.
waitTwoFrames()
val screenshotName = storyId.replace("--", "_")
instrumentation.runOnMainSync {
setLayerTypeSoftwareRecursively(view)
Screenshot.snap(view).setName(screenshotName).record()
}
Log.d(TAG, "Screenshot captured: $screenshotName")
successCount++
} catch (e: Exception) {
failureCount++
val errorMsg = "${story.title}/${story.name}: ${e.message}"
failures.add(errorMsg)
Log.e(TAG, "Failed to screenshot story: $errorMsg", e)
failures.add("$storyId: ${e.message}")
Log.e(TAG, "Failed to screenshot story: $storyId", e)
} finally {
StorybookRegistry.resolveCurrentStory()
}
}

Log.d(TAG, "Screenshot results: $successCount passed, $failureCount failed")
Log.d(TAG, "Screenshot results: $successCount passed, ${failures.size} failed")
if (failures.isNotEmpty()) {
Log.e(TAG, "Failed stories:\n${failures.joinToString("\n")}")
}

assertTrue("No stories were screenshotted", successCount > 0)
assertTrue(
"Some stories failed to screenshot: ${failures.joinToString(", ")}",
failures.isEmpty()
)
}

private fun screenshotStory(storyInfo: StoryInfo) {
val storyName = storyInfo.toStoryName()
Log.d(TAG, "Screenshotting: $storyName (id: ${storyInfo.id})")

StorybookRegistry.prepareForNextStory()
renderStory(storyName) { view ->
StorybookRegistry.awaitStoryReady(getLoadTimeoutMs())
val screenshotName = storyInfo.id.replace("--", "_")
Screenshot.snap(view).setName(screenshotName).record()
Log.d(TAG, "Screenshot captured: $screenshotName")
}
}

private fun bootstrapManifest(manifestFile: File) {
Log.d(TAG, "Launching StoryRenderer to generate manifest...")
renderStory(BOOTSTRAP_STORY_NAME) {
waitForManifestFile(manifestFile)
}
Log.d(TAG, "Bootstrap complete")
}

/**
* Renders the given story name into a view, calls [onRendered] with that view,
* then tears down. Handles both old arch (ReactRootView) and new arch (ReactSurface).
* Mounts the StoryRenderer surface, calls [onMounted] with the view, then tears down.
* Handles both new arch (ReactHost/ReactSurface) and old arch (ReactRootView).
*/
private fun renderStory(storyName: String, onRendered: (view: View) -> Unit) {
private fun mountSurface(onMounted: (view: View) -> Unit) {
val instrumentation = InstrumentationRegistry.getInstrumentation()
val app = instrumentation.targetContext.applicationContext as ReactApplication
val props = Bundle().apply { putString("storyName", storyName) }

val reactHost = app.reactHost
if (reactHost != null) {
Expand All @@ -169,7 +155,7 @@ abstract class BaseStoryScreenshotTest {
val surface = reactHost.createSurface(
context,
getMainComponentName(),
props
Bundle()
)

val view = surface.view
Expand Down Expand Up @@ -198,7 +184,7 @@ abstract class BaseStoryScreenshotTest {
surface.start()
}

onRendered(view)
onMounted(view)

instrumentation.runOnMainSync {
surface.stop()
Expand All @@ -214,7 +200,7 @@ abstract class BaseStoryScreenshotTest {

// ReactRootView.startReactApplication() checks isOnUiThread() internally.
instrumentation.runOnMainSync {
rootView.startReactApplication(reactInstanceManager, getMainComponentName(), props)
rootView.startReactApplication(reactInstanceManager, getMainComponentName(), Bundle())
}

// setupView().layout() calls measure()+layout() at the fixed dimensions, which
Expand All @@ -224,22 +210,43 @@ abstract class BaseStoryScreenshotTest {
.setExactHeightPx(SCREEN_HEIGHT_PX)
.layout()

onRendered(rootView)
onMounted(rootView)

instrumentation.runOnMainSync { rootView.unmountReactApplication() }
}
}

private fun waitForManifestFile(manifestFile: File) {
val deadline = System.currentTimeMillis() + getBootstrapTimeoutMs()
while (!manifestFile.exists() && System.currentTimeMillis() < deadline) {
Thread.sleep(100)
/**
* Recursively sets LAYER_TYPE_SOFTWARE on a view and all its descendants.
*
* view.draw(canvas) cannot capture children that have hardware display lists.
* Fabric child views in a hardware-accelerated WindowManager window get hardware
* display lists by default, so they render blank into a software canvas.
* Walking the tree and forcing software layers ensures draw() sees all content.
*/
private fun setLayerTypeSoftwareRecursively(view: View) {
view.setLayerType(View.LAYER_TYPE_SOFTWARE, null)
if (view is ViewGroup) {
for (i in 0 until view.childCount) {
setLayerTypeSoftwareRecursively(view.getChildAt(i))
}
}
if (!manifestFile.exists()) {
throw IllegalStateException(
"Manifest file did not appear within ${getBootstrapTimeoutMs()}ms. " +
"Make sure configure(view) is called in your app and the StoryRenderer is registered."
)
}

/**
* Waits for two Choreographer frames on the main thread.
*
* After useEffect fires (React commit), Fabric still needs to apply its
* native mutations in the next frame(s). Waiting two frames ensures the
* shadow tree is fully flushed to native views before we screenshot.
*/
private fun waitTwoFrames() {
repeat(2) {
val latch = CountDownLatch(1)
Handler(Looper.getMainLooper()).post {
Choreographer.getInstance().postFrameCallback { latch.countDown() }
}
latch.await()
}
}
}
Loading
Loading