diff --git a/packages/example/pubspec.lock b/packages/example/pubspec.lock index e41bd066..e1df4131 100644 --- a/packages/example/pubspec.lock +++ b/packages/example/pubspec.lock @@ -69,10 +69,10 @@ packages: dependency: transitive description: name: characters - sha256: f71061c654a3380576a52b451dd5532377954cf9dbd272a78fc8479606670803 + sha256: faf38497bda5ead2a8c7615f4f7939df04333478bf32e4173fcb06d428b5716b url: "https://pub.dev" source: hosted - version: "1.4.0" + version: "1.4.1" clock: dependency: transitive description: @@ -454,26 +454,26 @@ packages: dependency: transitive description: name: matcher - sha256: dc58c723c3c24bf8d3e2d3ad3f2f9d7bd9cf43ec6feaa64181775e60190153f2 + sha256: "12956d0ad8390bbcc63ca2e1469c0619946ccb52809807067a7020d57e647aa6" url: "https://pub.dev" source: hosted - version: "0.12.17" + version: "0.12.18" material_color_utilities: dependency: transitive description: name: material_color_utilities - sha256: f7142bb1154231d7ea5f96bc7bde4bda2a0945d2806bb11670e30b850d56bdec + sha256: "9c337007e82b1889149c82ed242ed1cb24a66044e30979c44912381e9be4c48b" url: "https://pub.dev" source: hosted - version: "0.11.1" + version: "0.13.0" meta: dependency: transitive description: name: meta - sha256: e3641ec5d63ebf0d9b41bd43201a66e3fc79a65db5f61fc181f04cd27aab950c + sha256: "23f08335362185a5ea2ad3a4e597f1375e78bce8a040df5c600c8d3552ef2394" url: "https://pub.dev" source: hosted - version: "1.16.0" + version: "1.17.0" mime: dependency: transitive description: @@ -611,10 +611,10 @@ packages: dependency: transitive description: name: test_api - sha256: "522f00f556e73044315fa4585ec3270f1808a4b186c936e612cab0b565ff1e00" + sha256: "93167629bfc610f71560ab9312acdda4959de4df6fac7492c89ff0d3886f6636" url: "https://pub.dev" source: hosted - version: "0.7.6" + version: "0.7.9" typed_data: dependency: transitive description: diff --git a/packages/google_mlkit_barcode_scanning/android/build.gradle b/packages/google_mlkit_barcode_scanning/android/build.gradle index f305510f..941c71ce 100644 --- a/packages/google_mlkit_barcode_scanning/android/build.gradle +++ b/packages/google_mlkit_barcode_scanning/android/build.gradle @@ -2,6 +2,7 @@ group = "com.google_mlkit_barcode_scanning" version = "1.0" buildscript { + ext.kotlin_version = "2.2.20" repositories { google() mavenCentral() @@ -9,6 +10,7 @@ buildscript { dependencies { classpath("com.android.tools.build:gradle:8.13.0") + classpath("org.jetbrains.kotlin:kotlin-gradle-plugin:2.2.20") } } @@ -20,6 +22,7 @@ rootProject.allprojects { } apply plugin: "com.android.library" +apply plugin: "kotlin-android" android { namespace = "com.google_mlkit_barcode_scanning" @@ -31,11 +34,19 @@ android { targetCompatibility = JavaVersion.VERSION_11 } + kotlinOptions { + jvmTarget = '11' + } + + sourceSets { + main.java.srcDirs += "src/main/kotlin" + } + defaultConfig { minSdk = 21 } dependencies { - implementation("com.google.mlkit:barcode-scanning:17.3.0") + implementation("com.google.mlkit:barcode-scanning:17.3.0") } } diff --git a/packages/google_mlkit_barcode_scanning/android/gradle/wrapper/gradle-wrapper.properties b/packages/google_mlkit_barcode_scanning/android/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 00000000..128196a7 --- /dev/null +++ b/packages/google_mlkit_barcode_scanning/android/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-9.0-milestone-1-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/packages/google_mlkit_barcode_scanning/android/src/main/java/com/google_mlkit_barcode_scanning/BarcodeScanner.java b/packages/google_mlkit_barcode_scanning/android/src/main/java/com/google_mlkit_barcode_scanning/BarcodeScanner.java deleted file mode 100644 index 42fda756..00000000 --- a/packages/google_mlkit_barcode_scanning/android/src/main/java/com/google_mlkit_barcode_scanning/BarcodeScanner.java +++ /dev/null @@ -1,223 +0,0 @@ -package com.google_mlkit_barcode_scanning; - -import android.content.Context; -import android.graphics.Point; -import android.graphics.Rect; - -import androidx.annotation.NonNull; -import androidx.annotation.Nullable; - -import com.google.mlkit.vision.barcode.common.Barcode; -import com.google.mlkit.vision.barcode.BarcodeScannerOptions; -import com.google.mlkit.vision.barcode.BarcodeScanning; -import com.google.mlkit.vision.common.InputImage; -import com.google_mlkit_commons.InputImageConverter; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import io.flutter.plugin.common.MethodCall; -import io.flutter.plugin.common.MethodChannel; - -public class BarcodeScanner implements MethodChannel.MethodCallHandler { - private static final String START = "vision#startBarcodeScanner"; - private static final String CLOSE = "vision#closeBarcodeScanner"; - - private final Context context; - private final Map instances = new HashMap<>(); - - public BarcodeScanner(Context context) { - this.context = context; - } - - @Override - public void onMethodCall(@NonNull MethodCall call, @NonNull MethodChannel.Result result) { - String method = call.method; - switch (method) { - case START: - handleDetection(call, result); - break; - case CLOSE: - closeDetector(call); - result.success(null); - break; - default: - result.notImplemented(); - break; - } - } - - private com.google.mlkit.vision.barcode.BarcodeScanner initialize(MethodCall call) { - List formatList = call.argument("formats"); - BarcodeScannerOptions barcodeScannerOptions; - if (formatList.size() > 1) { - int[] array = new int[formatList.size()]; - for (int i = 1; i < formatList.size(); i++) { - array[i] = formatList.get(i); - } - barcodeScannerOptions = new BarcodeScannerOptions.Builder().setBarcodeFormats(formatList.get(0), array).build(); - } else { - barcodeScannerOptions = new BarcodeScannerOptions.Builder().setBarcodeFormats(formatList.get(0)).build(); - } - return BarcodeScanning.getClient(barcodeScannerOptions); - } - - private void handleDetection(MethodCall call, final MethodChannel.Result result) { - Map imageData = call.argument("imageData"); - InputImage inputImage = InputImageConverter.getInputImageFromData(imageData, context, result); - if (inputImage == null) return; - - String id = call.argument("id"); - com.google.mlkit.vision.barcode.BarcodeScanner barcodeScanner = instances.get(id); - if (barcodeScanner == null) { - barcodeScanner = initialize(call); - instances.put(id, barcodeScanner); - } - - barcodeScanner.process(inputImage) - .addOnSuccessListener(barcodes -> { - List> barcodeList = new ArrayList<>(barcodes.size()); - for (Barcode barcode : barcodes) { - - Map barcodeMap = new HashMap<>(); - int valueType = barcode.getValueType(); - barcodeMap.put("type", valueType); - barcodeMap.put("format", barcode.getFormat()); - barcodeMap.put("rawValue", barcode.getRawValue()); - barcodeMap.put("rawBytes", barcode.getRawBytes()); - barcodeMap.put("displayValue", barcode.getDisplayValue()); - barcodeMap.put("rect", getBoundingPoints(barcode.getBoundingBox())); - Point[] cornerPoints = barcode.getCornerPoints(); - List> points = new ArrayList<>(); - addPoints(cornerPoints, points); - barcodeMap.put("points", points); - switch (valueType) { - case Barcode.TYPE_UNKNOWN: - case Barcode.TYPE_ISBN: - case Barcode.TYPE_PRODUCT: - case Barcode.TYPE_TEXT: - break; - case Barcode.TYPE_WIFI: - barcodeMap.put("ssid", barcode.getWifi().getSsid()); - barcodeMap.put("password", barcode.getWifi().getPassword()); - barcodeMap.put("encryption", barcode.getWifi().getEncryptionType()); - break; - case Barcode.TYPE_URL: - barcodeMap.put("title", barcode.getUrl().getTitle()); - barcodeMap.put("url", barcode.getUrl().getUrl()); - break; - case Barcode.TYPE_EMAIL: - barcodeMap.put("address", barcode.getEmail().getAddress()); - barcodeMap.put("body", barcode.getEmail().getBody()); - barcodeMap.put("subject", barcode.getEmail().getSubject()); - barcodeMap.put("emailType", barcode.getEmail().getType()); - break; - case Barcode.TYPE_PHONE: - barcodeMap.put("number", barcode.getPhone().getNumber()); - barcodeMap.put("phoneType", barcode.getPhone().getType()); - break; - case Barcode.TYPE_SMS: - barcodeMap.put("message", barcode.getSms().getMessage()); - barcodeMap.put("number", barcode.getSms().getPhoneNumber()); - break; - case Barcode.TYPE_GEO: - barcodeMap.put("latitude", barcode.getGeoPoint().getLat()); - barcodeMap.put("longitude", barcode.getGeoPoint().getLng()); - break; - case Barcode.TYPE_DRIVER_LICENSE: - barcodeMap.put("addressCity", barcode.getDriverLicense().getAddressCity()); - barcodeMap.put("addressState", barcode.getDriverLicense().getAddressState()); - barcodeMap.put("addressZip", barcode.getDriverLicense().getAddressZip()); - barcodeMap.put("addressStreet", barcode.getDriverLicense().getAddressStreet()); - barcodeMap.put("issueDate", barcode.getDriverLicense().getIssueDate()); - barcodeMap.put("birthDate", barcode.getDriverLicense().getBirthDate()); - barcodeMap.put("expiryDate", barcode.getDriverLicense().getExpiryDate()); - barcodeMap.put("gender", barcode.getDriverLicense().getGender()); - barcodeMap.put("licenseNumber", barcode.getDriverLicense().getLicenseNumber()); - barcodeMap.put("firstName", barcode.getDriverLicense().getFirstName()); - barcodeMap.put("lastName", barcode.getDriverLicense().getLastName()); - barcodeMap.put("country", barcode.getDriverLicense().getIssuingCountry()); - break; - case Barcode.TYPE_CONTACT_INFO: - barcodeMap.put("firstName", barcode.getContactInfo().getName().getFirst()); - barcodeMap.put("lastName", barcode.getContactInfo().getName().getLast()); - barcodeMap.put("formattedName", barcode.getContactInfo().getName().getFormattedName()); - barcodeMap.put("organization", barcode.getContactInfo().getOrganization()); - List> queries = new ArrayList<>(); - for (Barcode.Address address : barcode.getContactInfo().getAddresses()) { - Map addressMap = new HashMap<>(); - addressMap.put("addressType", address.getType()); - List addressLines = new ArrayList<>(); - Collections.addAll(addressLines, address.getAddressLines()); - addressMap.put("addressLines", addressLines); - queries.add(addressMap); - } - barcodeMap.put("addresses", queries); - queries = new ArrayList<>(); - for (Barcode.Phone phone : barcode.getContactInfo().getPhones()) { - Map phoneMap = new HashMap<>(); - phoneMap.put("number", phone.getNumber()); - phoneMap.put("phoneType", phone.getType()); - queries.add(phoneMap); - } - barcodeMap.put("phones", queries); - queries = new ArrayList<>(); - for (Barcode.Email email : barcode.getContactInfo().getEmails()) { - Map emailMap = new HashMap<>(); - emailMap.put("address", email.getAddress()); - emailMap.put("body", email.getBody()); - emailMap.put("subject", email.getSubject()); - emailMap.put("emailType", email.getType()); - queries.add(emailMap); - } - barcodeMap.put("emails", queries); - List urls = new ArrayList<>(barcode.getContactInfo().getUrls()); - barcodeMap.put("urls", urls); - break; - case Barcode.TYPE_CALENDAR_EVENT: - barcodeMap.put("description", barcode.getCalendarEvent().getDescription()); - barcodeMap.put("location", barcode.getCalendarEvent().getLocation()); - barcodeMap.put("status", barcode.getCalendarEvent().getStatus()); - barcodeMap.put("summary", barcode.getCalendarEvent().getSummary()); - barcodeMap.put("organizer", barcode.getCalendarEvent().getOrganizer()); - barcodeMap.put("start", barcode.getCalendarEvent().getStart().getRawValue()); - barcodeMap.put("end", barcode.getCalendarEvent().getEnd().getRawValue()); - break; - } - barcodeList.add(barcodeMap); - } - result.success(barcodeList); - }) - .addOnFailureListener(e -> result.error("BarcodeDetectorError", e.toString(), null)); - } - - private void addPoints(Point[] cornerPoints, List> points) { - for (Point point : cornerPoints) { - Map p = new HashMap<>(); - p.put("x", point.x); - p.put("y", point.y); - points.add(p); - } - } - - private Map getBoundingPoints(@Nullable Rect rect) { - Map frame = new HashMap<>(); - if (rect == null) return frame; - frame.put("left", rect.left); - frame.put("right", rect.right); - frame.put("top", rect.top); - frame.put("bottom", rect.bottom); - return frame; - } - - private void closeDetector(MethodCall call) { - String id = call.argument("id"); - com.google.mlkit.vision.barcode.BarcodeScanner barcodeScanner = instances.get(id); - if (barcodeScanner == null) return; - barcodeScanner.close(); - instances.remove(id); - } -} diff --git a/packages/google_mlkit_barcode_scanning/android/src/main/java/com/google_mlkit_barcode_scanning/GoogleMlKitBarcodeScanningPlugin.java b/packages/google_mlkit_barcode_scanning/android/src/main/java/com/google_mlkit_barcode_scanning/GoogleMlKitBarcodeScanningPlugin.java deleted file mode 100644 index 23e70774..00000000 --- a/packages/google_mlkit_barcode_scanning/android/src/main/java/com/google_mlkit_barcode_scanning/GoogleMlKitBarcodeScanningPlugin.java +++ /dev/null @@ -1,22 +0,0 @@ -package com.google_mlkit_barcode_scanning; - -import androidx.annotation.NonNull; - -import io.flutter.embedding.engine.plugins.FlutterPlugin; -import io.flutter.plugin.common.MethodChannel; - -public class GoogleMlKitBarcodeScanningPlugin implements FlutterPlugin { - private MethodChannel channel; - private static final String channelName = "google_mlkit_barcode_scanning"; - - @Override - public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) { - channel = new MethodChannel(flutterPluginBinding.getBinaryMessenger(), channelName); - channel.setMethodCallHandler(new BarcodeScanner(flutterPluginBinding.getApplicationContext())); - } - - @Override - public void onDetachedFromEngine(@NonNull FlutterPluginBinding binding) { - channel.setMethodCallHandler(null); - } -} diff --git a/packages/google_mlkit_barcode_scanning/android/src/main/kotlin/com/google_mlkit_barcode_scanning/BarcodeScanner.kt b/packages/google_mlkit_barcode_scanning/android/src/main/kotlin/com/google_mlkit_barcode_scanning/BarcodeScanner.kt new file mode 100644 index 00000000..85f72f58 --- /dev/null +++ b/packages/google_mlkit_barcode_scanning/android/src/main/kotlin/com/google_mlkit_barcode_scanning/BarcodeScanner.kt @@ -0,0 +1,233 @@ +package com.google_mlkit_barcode_scanning + +import android.content.Context +import android.graphics.Point +import android.graphics.Rect +import com.google.mlkit.vision.barcode.BarcodeScannerOptions +import com.google.mlkit.vision.barcode.BarcodeScanning +import com.google.mlkit.vision.barcode.common.Barcode +import com.google_mlkit_commons.InputImageConverter +import io.flutter.plugin.common.MethodCall +import io.flutter.plugin.common.MethodChannel + +class BarcodeScanner( + private val context: Context, +) : MethodChannel.MethodCallHandler { + private val instances = mutableMapOf() + + companion object { + private const val START = "vision#startBarcodeScanner" + private const val CLOSE = "vision#closeBarcodeScanner" + } + + override fun onMethodCall( + call: MethodCall, + result: MethodChannel.Result, + ) { + when (call.method) { + START -> { + handleDetection(call, result) + } + + CLOSE -> { + closeDetector(call) + result.success(null) + } + + else -> { + result.notImplemented() + } + } + } + + private fun initialize(call: MethodCall): com.google.mlkit.vision.barcode.BarcodeScanner { + val formatList = call.argument>("formats")!! + val options = + if (formatList.size > 1) { + val rest = formatList.drop(1).toIntArray() + BarcodeScannerOptions + .Builder() + .setBarcodeFormats(formatList[0], *rest) + .build() + } else { + BarcodeScannerOptions + .Builder() + .setBarcodeFormats(formatList[0]) + .build() + } + + return BarcodeScanning.getClient(options) + } + + private fun handleDetection( + call: MethodCall, + result: MethodChannel.Result, + ) { + val imageData = + call.argument>("imageData") ?: run { + result.error("BarcodeDetectorError", "imageData is null", null) + return + } + + val inputImage = InputImageConverter.getInputImageFromData(imageData, context, result) ?: return + + val id = call.argument("id")!! + val scanner = instances.getOrPut(id) { initialize(call) } + + scanner + .process(inputImage) + .addOnSuccessListener { barcodes -> + val barcodeList = + barcodes.map { barcode -> + buildMap { + val valueType = barcode.valueType + put("type", valueType) + put("format", barcode.format) + put("rawValue", barcode.rawValue) + put("rawBytes", barcode.rawBytes) + put("displayValue", barcode.displayValue) + put("rect", getBoundingPoints(barcode.boundingBox)) + put("points", barcode.cornerPoints?.let { getPoints(it) } ?: emptyList()) + + when (valueType) { + Barcode.TYPE_UNKNOWN, + Barcode.TYPE_ISBN, + Barcode.TYPE_PRODUCT, + Barcode.TYPE_TEXT, + -> { + Unit + } + + Barcode.TYPE_WIFI -> { + barcode.wifi?.let { + put("ssid", it.ssid) + put("password", it.password) + put("encryption", it.encryptionType) + } + } + + Barcode.TYPE_URL -> { + barcode.url?.let { + put("title", it.title) + put("url", it.url) + } + } + + Barcode.TYPE_EMAIL -> { + barcode.email?.let { + put("address", it.address) + put("body", it.body) + put("subject", it.subject) + put("emailType", it.type) + } + } + + Barcode.TYPE_PHONE -> { + barcode.phone?.let { + put("number", it.number) + put("phoneType", it.type) + } + } + + Barcode.TYPE_SMS -> { + barcode.sms?.let { + put("message", it.message) + put("number", it.phoneNumber) + } + } + + Barcode.TYPE_GEO -> { + barcode.geoPoint?.let { + put("latitude", it.lat) + put("longitude", it.lng) + } + } + + Barcode.TYPE_DRIVER_LICENSE -> { + barcode.driverLicense?.let { + put("addressCity", it.addressCity) + put("addressState", it.addressState) + put("addressZip", it.addressZip) + put("addressStreet", it.addressStreet) + put("issueDate", it.issueDate) + put("birthDate", it.birthDate) + put("expiryDate", it.expiryDate) + put("gender", it.gender) + put("licenseNumber", it.licenseNumber) + put("firstName", it.firstName) + put("lastName", it.lastName) + put("country", it.issuingCountry) + } + } + + Barcode.TYPE_CONTACT_INFO -> { + barcode.contactInfo?.let { contact -> + put("firstName", contact.name?.first) + put("lastName", contact.name?.last) + put("formattedName", contact.name?.formattedName) + put("organization", contact.organization) + put( + "addresses", + contact.addresses.map { address -> + mapOf( + "addressType" to address.type, + "addressLines" to address.addressLines.toList(), + ) + }, + ) + put( + "phones", + contact.phones.map { phone -> + mapOf("number" to phone.number, "phoneType" to phone.type) + }, + ) + put( + "emails", + contact.emails.map { email -> + mapOf( + "address" to email.address, + "body" to email.body, + "subject" to email.subject, + "emailType" to email.type, + ) + }, + ) + put("urls", contact.urls.toList()) + } + } + + Barcode.TYPE_CALENDAR_EVENT -> { + barcode.calendarEvent?.let { + put("description", it.description) + put("location", it.location) + put("status", it.status) + put("summary", it.summary) + put("organizer", it.organizer) + put("start", it.start?.rawValue) + put("end", it.end?.rawValue) + } + } + } + } + } + result.success(barcodeList) + }.addOnFailureListener { e -> + result.error("BarcodeDetectorError", e.toString(), null) + } + } + + private fun getPoints(cornerPoints: Array) = + cornerPoints.map { + mapOf("x" to it.x, "y" to it.y) + } + + private fun getBoundingPoints(rect: Rect?) = + rect?.let { + mapOf("left" to it.left, "right" to it.right, "top" to it.top, "bottom" to it.bottom) + } ?: emptyMap() + + private fun closeDetector(call: MethodCall) { + val id = call.argument("id") ?: return + instances.remove(id)?.close() + } +} diff --git a/packages/google_mlkit_barcode_scanning/android/src/main/kotlin/com/google_mlkit_barcode_scanning/GoogleMlKitBarcodeScanningPlugin.kt b/packages/google_mlkit_barcode_scanning/android/src/main/kotlin/com/google_mlkit_barcode_scanning/GoogleMlKitBarcodeScanningPlugin.kt new file mode 100644 index 00000000..3a3fe8ad --- /dev/null +++ b/packages/google_mlkit_barcode_scanning/android/src/main/kotlin/com/google_mlkit_barcode_scanning/GoogleMlKitBarcodeScanningPlugin.kt @@ -0,0 +1,31 @@ +package com.google_mlkit_barcode_scanning + +import io.flutter.embedding.engine.plugins.FlutterPlugin +import io.flutter.plugin.common.MethodCall +import io.flutter.plugin.common.MethodChannel + +class GoogleMlKitBarcodeScanningPlugin : + FlutterPlugin, + MethodChannel.MethodCallHandler { + private lateinit var channel: MethodChannel + + companion object { + private const val CHANNEL_NAME = "google_mlkit_barcode_scanning" + } + + override fun onAttachedToEngine(flutterPluginBinding: FlutterPlugin.FlutterPluginBinding) { + channel = MethodChannel(flutterPluginBinding.binaryMessenger, CHANNEL_NAME) + channel.setMethodCallHandler(BarcodeScanner(flutterPluginBinding.applicationContext)) + } + + override fun onDetachedFromEngine(binding: FlutterPlugin.FlutterPluginBinding) { + channel.setMethodCallHandler(null) + } + + override fun onMethodCall( + call: MethodCall, + result: MethodChannel.Result, + ) { + result.notImplemented() + } +} diff --git a/packages/google_mlkit_digital_ink_recognition/android/build.gradle b/packages/google_mlkit_digital_ink_recognition/android/build.gradle index d8e99072..cb4e0624 100644 --- a/packages/google_mlkit_digital_ink_recognition/android/build.gradle +++ b/packages/google_mlkit_digital_ink_recognition/android/build.gradle @@ -2,6 +2,7 @@ group = "com.google_mlkit_digital_ink_recognition" version = "1.0" buildscript { + ext.kotlin_version = "2.2.20" repositories { google() mavenCentral() @@ -9,6 +10,7 @@ buildscript { dependencies { classpath("com.android.tools.build:gradle:8.13.0") + classpath('org.jetbrains.kotlin:kotlin-gradle-plugin:2.2.20') } } @@ -20,6 +22,7 @@ rootProject.allprojects { } apply plugin: "com.android.library" +apply plugin: "kotlin-android" android { namespace = "com.google_mlkit_digital_ink_recognition" @@ -31,6 +34,14 @@ android { targetCompatibility = JavaVersion.VERSION_11 } + kotlinOptions { + jvmTarget = '11' + } + + sourceSets { + main.java.srcDirs += "src/main/kotlin" + } + defaultConfig { minSdk = 21 } diff --git a/packages/google_mlkit_digital_ink_recognition/android/src/main/java/com/google_mlkit_digital_ink_recognition/DigitalInkRecognizer.java b/packages/google_mlkit_digital_ink_recognition/android/src/main/java/com/google_mlkit_digital_ink_recognition/DigitalInkRecognizer.java deleted file mode 100644 index ebd0a45c..00000000 --- a/packages/google_mlkit_digital_ink_recognition/android/src/main/java/com/google_mlkit_digital_ink_recognition/DigitalInkRecognizer.java +++ /dev/null @@ -1,189 +0,0 @@ -package com.google_mlkit_digital_ink_recognition; - -import androidx.annotation.NonNull; - -import com.google.mlkit.common.MlKitException; -import com.google.mlkit.vision.digitalink.recognition.DigitalInkRecognition; -import com.google.mlkit.vision.digitalink.recognition.DigitalInkRecognitionModel; -import com.google.mlkit.vision.digitalink.recognition.DigitalInkRecognitionModelIdentifier; -import com.google.mlkit.vision.digitalink.recognition.DigitalInkRecognizerOptions; -import com.google.mlkit.vision.digitalink.recognition.Ink; -import com.google.mlkit.vision.digitalink.common.RecognitionCandidate; -import com.google.mlkit.vision.digitalink.recognition.RecognitionContext; -import com.google.mlkit.vision.digitalink.common.RecognitionResult; -import com.google.mlkit.vision.digitalink.recognition.WritingArea; -import com.google_mlkit_commons.GenericModelManager; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import io.flutter.plugin.common.MethodCall; -import io.flutter.plugin.common.MethodChannel; - -public class DigitalInkRecognizer implements MethodChannel.MethodCallHandler { - private static final String START = "vision#startDigitalInkRecognizer"; - private static final String CLOSE = "vision#closeDigitalInkRecognizer"; - private static final String MANAGE = "vision#manageInkModels"; - - private final Map instances = new HashMap<>(); - private final GenericModelManager genericModelManager = new GenericModelManager(); - - @Override - public void onMethodCall(@NonNull MethodCall call, @NonNull MethodChannel.Result result) { - String method = call.method; - switch (method) { - case START: - handleDetection(call, result); - break; - case CLOSE: - closeDetector(call); - break; - case MANAGE: - manageModel(call, result); - break; - default: - result.notImplemented(); - break; - } - } - - private void handleDetection(MethodCall call, final MethodChannel.Result result) { - String tag = call.argument("model"); - DigitalInkRecognitionModel model = getModel(tag, result); - if (model == null) - return; - - genericModelManager.isModelDownloaded( - model, - new GenericModelManager.CheckModelIsDownloadedCallback() { - @Override - public void onCheckResult(Boolean isDownloaded) { - if (!isDownloaded) { - result.error("Model Error", "Model has not been downloaded yet ", null); - return; - } - - handleInkDetectionIfModelDownloaded(call, result, model); - } - - @Override - public void onError(Exception e) { - result.error("Model download check failed", e.toString(), e); - } - } - ); - } - - private void handleInkDetectionIfModelDownloaded( - MethodCall call, - final MethodChannel.Result result, - DigitalInkRecognitionModel model - ) { - String id = call.argument("id"); - com.google.mlkit.vision.digitalink.recognition.DigitalInkRecognizer recognizer = instances.get(id); - if (recognizer == null) { - recognizer = DigitalInkRecognition.getClient(DigitalInkRecognizerOptions.builder(model).build()); - instances.put(id, recognizer); - } - - Map inkMap = call.argument("ink"); - List> strokeList = (List>) inkMap.get("strokes"); - Ink.Builder inkBuilder = Ink.builder(); - for (final Map strokeMap : strokeList) { - Ink.Stroke.Builder strokeBuilder = Ink.Stroke.builder(); - List> pointsList = (List>) strokeMap.get("points"); - for (final Map point : pointsList) { - float x = (float) (double) point.get("x"); - float y = (float) (double) point.get("y"); - Object t0 = point.get("t"); - long t; - if (t0 instanceof Integer) { - t = (int) t0; - } else { - t = (long) t0; - } - Ink.Point strokePoint = Ink.Point.create(x, y, t); - strokeBuilder.addPoint(strokePoint); - } - inkBuilder.addStroke(strokeBuilder.build()); - } - Ink ink = inkBuilder.build(); - - RecognitionContext context = null; - Map contextMap = call.argument("context"); - if (contextMap != null) { - RecognitionContext.Builder builder = RecognitionContext.builder(); - String preContext = (String) contextMap.get("preContext"); - if (preContext != null) { - builder.setPreContext(preContext); - } else { - builder.setPreContext(""); - } - - Map writingAreaMap = (Map) contextMap.get("writingArea"); - if (writingAreaMap != null) { - float width = (float) (double) writingAreaMap.get("width"); - float height = (float) (double) writingAreaMap.get("height"); - builder.setWritingArea(new WritingArea(width, height)); - } - - context = builder.build(); - } - - if (context != null) { - recognizer.recognize(ink, context) - .addOnSuccessListener(recognitionResult -> process(recognitionResult, result)) - .addOnFailureListener(e -> result.error("recognition Error", e.toString(), null)); - } else { - recognizer.recognize(ink) - .addOnSuccessListener(recognitionResult -> process(recognitionResult, result)) - .addOnFailureListener(e -> result.error("recognition Error", e.toString(), null)); - } - } - - private void process(RecognitionResult recognitionResult, final MethodChannel.Result result) { - List> candidatesList = new ArrayList<>(recognitionResult.getCandidates().size()); - for (RecognitionCandidate candidate : recognitionResult.getCandidates()) { - Map candidateData = new HashMap<>(); - double score = 0; - if (candidate.getScore() != null) - score = candidate.getScore().doubleValue(); - candidateData.put("text", candidate.getText()); - candidateData.put("score", score); - candidatesList.add(candidateData); - } - result.success(candidatesList); - } - - private void closeDetector(MethodCall call) { - String id = call.argument("id"); - com.google.mlkit.vision.digitalink.recognition.DigitalInkRecognizer recognizer = instances.get(id); - if (recognizer == null) - return; - recognizer.close(); - instances.remove(id); - } - - private void manageModel(MethodCall call, final MethodChannel.Result result) { - String tag = call.argument("model"); - DigitalInkRecognitionModel model = getModel(tag, result); - genericModelManager.manageModel(model, call, result); - } - - private DigitalInkRecognitionModel getModel(String tag, final MethodChannel.Result result) { - DigitalInkRecognitionModelIdentifier modelIdentifier; - try { - modelIdentifier = DigitalInkRecognitionModelIdentifier.fromLanguageTag(tag); - } catch (MlKitException e) { - result.error("Failed to create model identifier", e.toString(), null); - return null; - } - if (modelIdentifier == null) { - result.error("Model Identifier error", "No model was found", null); - return null; - } - return DigitalInkRecognitionModel.builder(modelIdentifier).build(); - } -} \ No newline at end of file diff --git a/packages/google_mlkit_digital_ink_recognition/android/src/main/java/com/google_mlkit_digital_ink_recognition/GoogleMlKitDigitalInkRecognitionPlugin.java b/packages/google_mlkit_digital_ink_recognition/android/src/main/java/com/google_mlkit_digital_ink_recognition/GoogleMlKitDigitalInkRecognitionPlugin.java deleted file mode 100644 index 9166f25f..00000000 --- a/packages/google_mlkit_digital_ink_recognition/android/src/main/java/com/google_mlkit_digital_ink_recognition/GoogleMlKitDigitalInkRecognitionPlugin.java +++ /dev/null @@ -1,22 +0,0 @@ -package com.google_mlkit_digital_ink_recognition; - -import androidx.annotation.NonNull; - -import io.flutter.embedding.engine.plugins.FlutterPlugin; -import io.flutter.plugin.common.MethodChannel; - -public class GoogleMlKitDigitalInkRecognitionPlugin implements FlutterPlugin { - private MethodChannel channel; - private static final String channelName = "google_mlkit_digital_ink_recognizer"; - - @Override - public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) { - channel = new MethodChannel(flutterPluginBinding.getBinaryMessenger(), channelName); - channel.setMethodCallHandler(new DigitalInkRecognizer()); - } - - @Override - public void onDetachedFromEngine(@NonNull FlutterPluginBinding binding) { - channel.setMethodCallHandler(null); - } -} diff --git a/packages/google_mlkit_digital_ink_recognition/android/src/main/kotlin/com/google_mlkit_digital_ink_recognition/DigitalInkRecognizer.kt b/packages/google_mlkit_digital_ink_recognition/android/src/main/kotlin/com/google_mlkit_digital_ink_recognition/DigitalInkRecognizer.kt new file mode 100644 index 00000000..b4a5e82c --- /dev/null +++ b/packages/google_mlkit_digital_ink_recognition/android/src/main/kotlin/com/google_mlkit_digital_ink_recognition/DigitalInkRecognizer.kt @@ -0,0 +1,182 @@ +package com.google_mlkit_digital_ink_recognition + +import com.google.mlkit.common.MlKitException +import com.google.mlkit.vision.digitalink.common.RecognitionResult +import com.google.mlkit.vision.digitalink.recognition.DigitalInkRecognition +import com.google.mlkit.vision.digitalink.recognition.DigitalInkRecognitionModel +import com.google.mlkit.vision.digitalink.recognition.DigitalInkRecognitionModelIdentifier +import com.google.mlkit.vision.digitalink.recognition.DigitalInkRecognizerOptions +import com.google.mlkit.vision.digitalink.recognition.Ink +import com.google.mlkit.vision.digitalink.recognition.RecognitionContext +import com.google.mlkit.vision.digitalink.recognition.WritingArea +import com.google_mlkit_commons.GenericModelManager +import io.flutter.plugin.common.MethodCall +import io.flutter.plugin.common.MethodChannel + +class DigitalInkRecognizer : MethodChannel.MethodCallHandler { + private val instances = mutableMapOf() + private val genericModelManager = GenericModelManager() + + companion object { + private const val START = "vision#startDigitalInkRecognizer" + private const val CLOSE = "vision#closeDigitalInkRecognizer" + private const val MANAGE = "vision#manageInkModels" + } + + override fun onMethodCall( + call: MethodCall, + result: MethodChannel.Result, + ) { + when (call.method) { + START -> { + handleDetection(call, result) + } + + CLOSE -> { + closeDetector(call) + } + + MANAGE -> { + manageModel(call, result) + } + + else -> { + result.notImplemented() + } + } + } + + private fun handleDetection( + call: MethodCall, + result: MethodChannel.Result, + ) { + val tag = call.argument("model") + val model = getModel(tag!!, result) ?: return + + genericModelManager.isModelDownloaded( + model, + object : GenericModelManager.CheckModelIsDownloadedCallback { + override fun onCheckResult(isDownloaded: Boolean?) { + if (isDownloaded == false) { + result.error("Model Error", "Model has not been downloaded yet", null) + return + } + handleInkDetectionIfModelDownloaded(call, result, model) + } + + override fun onError(e: Exception) { + result.error("Model download check failed", e.toString(), e) + } + }, + ) + } + + private fun handleInkDetectionIfModelDownloaded( + call: MethodCall, + result: MethodChannel.Result, + model: DigitalInkRecognitionModel, + ) { + val id = call.argument("id")!! + val recognizer = + instances.getOrPut(id) { + DigitalInkRecognition.getClient(DigitalInkRecognizerOptions.builder(model).build()) + } + + val inkMap = call.argument>("ink")!! + val strokeList = inkMap["strokes"] as List> + val inkBuilder = Ink.builder() + + for (strokeMap in strokeList) { + val strokeBuilder = Ink.Stroke.builder() + val pointsList = strokeMap["points"] as List> + for (point in pointsList) { + val x = (point["x"] as Double).toFloat() + val y = (point["y"] as Double).toFloat() + val t = + when (val t0 = point["t"]) { + is Int -> t0.toLong() + else -> t0 as Long + } + + strokeBuilder.addPoint(Ink.Point.create(x, y, t)) + } + inkBuilder.addStroke(strokeBuilder.build()) + } + + val ink = inkBuilder.build() + + val contextMap = call.argument>("context") + val context = + contextMap?.let { + val builder = RecognitionContext.builder() + builder.setPreContext((it["preContext"] as? String) ?: "") + + (it["writingArea"] as? Map<*, *>)?.let { areaMap -> + val width = (areaMap["width"] as Double).toFloat() + val height = (areaMap["height"] as Double).toFloat() + builder.setWritingArea(WritingArea(width, height)) + } + builder.build() + } + + val onSuccess = { recognitionResult: RecognitionResult -> process(recognitionResult, result) } + val onFailure = { e: Exception -> result.error("recognition Error", e.toString(), null) } + + if (context != null) { + recognizer + .recognize(ink, context) + .addOnSuccessListener(onSuccess) + .addOnFailureListener(onFailure) + } else { + recognizer + .recognize(ink) + .addOnSuccessListener(onSuccess) + .addOnFailureListener(onFailure) + } + } + + private fun process( + recognitionResult: RecognitionResult, + result: MethodChannel.Result, + ) { + val candidateList = + recognitionResult.candidates.map { candidate -> + mapOf( + "text" to candidate.text, + "score" to (candidate.score?.toDouble() ?: 0.0), + ) + } + result.success(candidateList) + } + + private fun getModel( + tag: String, + result: MethodChannel.Result, + ): DigitalInkRecognitionModel? { + val modelIdentifier = + try { + DigitalInkRecognitionModelIdentifier.fromLanguageTag(tag) + } catch (e: MlKitException) { + result.error("Failed to create model identifier", e.toString(), null) + return null + } + + if (modelIdentifier == null) { + result.error("Model Identifier error", "No model was found", null) + return null + } + + return DigitalInkRecognitionModel.builder(modelIdentifier).build() + } + + private fun manageModel( + call: MethodCall, + result: MethodChannel.Result, + ) { + } + + private fun closeDetector(call: MethodCall) { + val id = call.argument("id") ?: return + instances.remove(id)?.close() + } +} diff --git a/packages/google_mlkit_digital_ink_recognition/android/src/main/kotlin/com/google_mlkit_digital_ink_recognition/GoogleMlKitDigitalInkRecognitionPlugin.kt b/packages/google_mlkit_digital_ink_recognition/android/src/main/kotlin/com/google_mlkit_digital_ink_recognition/GoogleMlKitDigitalInkRecognitionPlugin.kt new file mode 100644 index 00000000..014ef2e9 --- /dev/null +++ b/packages/google_mlkit_digital_ink_recognition/android/src/main/kotlin/com/google_mlkit_digital_ink_recognition/GoogleMlKitDigitalInkRecognitionPlugin.kt @@ -0,0 +1,31 @@ +package com.google_mlkit_digital_ink_recognition + +import io.flutter.embedding.engine.plugins.FlutterPlugin +import io.flutter.plugin.common.MethodCall +import io.flutter.plugin.common.MethodChannel + +class GoogleMlKitDigitalInkRecognitionPlugin : + FlutterPlugin, + MethodChannel.MethodCallHandler { + private lateinit var channel: MethodChannel + + override fun onAttachedToEngine(flutterPluginBinding: FlutterPlugin.FlutterPluginBinding) { + channel = MethodChannel(flutterPluginBinding.binaryMessenger, CHANNEL_NAME) + channel.setMethodCallHandler(DigitalInkRecognizer()) + } + + override fun onDetachedFromEngine(binding: FlutterPlugin.FlutterPluginBinding) { + channel.setMethodCallHandler(null) + } + + override fun onMethodCall( + call: MethodCall, + result: MethodChannel.Result, + ) { + result.notImplemented() + } + + companion object { + private const val CHANNEL_NAME = "google_mlkit_digital_ink_recognition" + } +} diff --git a/packages/google_mlkit_document_scanner/android/build.gradle b/packages/google_mlkit_document_scanner/android/build.gradle index 8264c2d4..4b3c1fe5 100644 --- a/packages/google_mlkit_document_scanner/android/build.gradle +++ b/packages/google_mlkit_document_scanner/android/build.gradle @@ -2,6 +2,7 @@ group = "com.google_mlkit_document_scanner" version = "1.0" buildscript { + ext.kotlin_version = "2.2.20" repositories { google() mavenCentral() @@ -9,6 +10,7 @@ buildscript { dependencies { classpath("com.android.tools.build:gradle:8.13.0") + classpath('org.jetbrains.kotlin:kotlin-gradle-plugin:2.2.20') } } @@ -20,6 +22,7 @@ rootProject.allprojects { } apply plugin: "com.android.library" +apply plugin: "kotlin-android" android { namespace = "com.google_mlkit_document_scanner" @@ -31,6 +34,15 @@ android { targetCompatibility = JavaVersion.VERSION_11 } + kotlinOptions { + jvmTarget = '11' + } + + sourceSets { + main.java.srcDirs += "src/main/kotlin" + } + + defaultConfig { minSdk = 21 } diff --git a/packages/google_mlkit_document_scanner/android/src/main/java/com/google_mlkit_document_scanner/DocumentScanner.java b/packages/google_mlkit_document_scanner/android/src/main/java/com/google_mlkit_document_scanner/DocumentScanner.java deleted file mode 100644 index 2720b220..00000000 --- a/packages/google_mlkit_document_scanner/android/src/main/java/com/google_mlkit_document_scanner/DocumentScanner.java +++ /dev/null @@ -1,198 +0,0 @@ -package com.google_mlkit_document_scanner; - -import android.app.Activity; -import android.content.Intent; -import android.content.IntentSender; - -import androidx.annotation.NonNull; -import androidx.annotation.Nullable; - -import com.google.android.gms.tasks.OnFailureListener; -import com.google.android.gms.tasks.OnSuccessListener; -import com.google.mlkit.vision.documentscanner.GmsDocumentScanner; -import com.google.mlkit.vision.documentscanner.GmsDocumentScanning; -import com.google.mlkit.vision.documentscanner.GmsDocumentScannerOptions; -import com.google.mlkit.vision.documentscanner.GmsDocumentScanningResult; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; - -import io.flutter.embedding.engine.plugins.activity.ActivityPluginBinding; -import io.flutter.plugin.common.MethodCall; -import io.flutter.plugin.common.MethodChannel; -import io.flutter.plugin.common.PluginRegistry; - -public class DocumentScanner implements MethodChannel.MethodCallHandler, PluginRegistry.ActivityResultListener { - private static final String START = "vision#startDocumentScanner"; - private static final String CLOSE = "vision#closeDocumentScanner"; - private static final String TAG = "DocumentScanner"; - private final Map instances = new HashMap<>(); - private final ActivityPluginBinding binding; - private MethodChannel.Result pendingResult = null; - final private int START_DOCUMENT_ACTIVITY = 0x362738; - - public DocumentScanner(ActivityPluginBinding binding) { - this.binding = binding; - binding.addActivityResultListener(this); - } - - @Override - public void onMethodCall(@NonNull MethodCall call, @NonNull MethodChannel.Result result) { - String method = call.method; - switch (method) { - case START: - handleScanner(call, result); - break; - case CLOSE: - closeScanner(call); - result.success(null); - break; - default: - result.notImplemented(); - break; - } - } - - private void handleScanner(MethodCall call, final MethodChannel.Result result) { - String id = call.argument("id"); - GmsDocumentScanner scanner = instances.get(id); - pendingResult = result; - - // Create a new scanner instance if it doesn't exist - if (scanner == null) { - Map options = call.argument("options"); - if (options == null) { - result.error(TAG, "Invalid options", null); - return; - } - GmsDocumentScannerOptions scannerOptions = parseOptions(options); - scanner = GmsDocumentScanning.getClient(scannerOptions); - instances.put(id, scanner); - } - - Activity activity = binding.getActivity(); - scanner.getStartScanIntent(activity).addOnSuccessListener(new OnSuccessListener() { - @Override - public void onSuccess(IntentSender intentSender) { - try { - activity.startIntentSenderForResult(intentSender, START_DOCUMENT_ACTIVITY, null, 0, 0, 0); - } catch (IntentSender.SendIntentException e) { - result.error(TAG, "Failed to start document scanner", null); - } - } - }).addOnFailureListener(new OnFailureListener() { - @Override - public void onFailure(@NonNull Exception e) { - result.error(TAG, "Failed to start document scanner", null); - } - }); - } - - // parse scanner options - private GmsDocumentScannerOptions parseOptions(Map options) { - boolean isGalleryImportAllowed = (boolean) options.get("isGalleryImport"); - int pageLimit = (int) options.get("pageLimit"); - List formatStrings = (List) options.get("formats"); - List formatConstants = new ArrayList<>(); - for (String format: formatStrings) { - switch (format) { - case "pdf": - formatConstants.add(GmsDocumentScannerOptions.RESULT_FORMAT_PDF); - break; - case "jpeg": - formatConstants.add(GmsDocumentScannerOptions.RESULT_FORMAT_JPEG); - break; - default: - throw new IllegalArgumentException("Not a format:" + options.get("format")); - } - } - - int mode; - switch ((String) options.get("mode")) { - case "base": - mode = GmsDocumentScannerOptions.SCANNER_MODE_BASE; - break; - case "filter": - mode = GmsDocumentScannerOptions.SCANNER_MODE_BASE_WITH_FILTER; - break; - case "full": - mode = GmsDocumentScannerOptions.SCANNER_MODE_FULL; - break; - default: - throw new IllegalArgumentException("Not a mode:" + options.get("mode")); - } - GmsDocumentScannerOptions.Builder builder = new GmsDocumentScannerOptions - .Builder() - .setGalleryImportAllowed(isGalleryImportAllowed) - .setPageLimit(pageLimit) - .setScannerMode(mode); - - // Set formats - if (!formatConstants.isEmpty()) { - if(formatConstants.size() > 1) { - builder.setResultFormats(formatConstants.get(0), formatConstants.get(1)); - } else { - builder.setResultFormats(formatConstants.get(0)); - } - - } - return builder.build(); - } - - private void closeScanner(MethodCall call) { - String id = call.argument("id"); - GmsDocumentScanner scanner = instances.get(id); - if (scanner == null) return; - instances.remove(id); - } - - @Override - public boolean onActivityResult(int requestCode, int resultCode, @Nullable Intent intent) { - if (requestCode == START_DOCUMENT_ACTIVITY) { - if (resultCode == Activity.RESULT_OK) { - GmsDocumentScanningResult result = GmsDocumentScanningResult.fromActivityResultIntent(intent); - if (result != null) { - handleScanningResult(result); - } - } else if (resultCode == Activity.RESULT_CANCELED) { - pendingResult.error(TAG, "Operation cancelled", null); - } else { - pendingResult.error(TAG, "Unknown Error", null); - } - return true; - } - return false; - } - - private void handleScanningResult(GmsDocumentScanningResult result) { - Map resultMap = new HashMap<>(); - - // Check if the result has a pdf - GmsDocumentScanningResult.Pdf pdf = result.getPdf(); - if (pdf != null) { - Map pdfMap = new HashMap<>(); - pdfMap.put("pageCount", pdf.getPageCount()); - pdfMap.put("uri", pdf.getUri().getPath()); - resultMap.put("pdf", pdfMap); - } else { - resultMap.put("pdf", null); - } - - // Check if the result has a list of pages - List pages = result.getPages(); - if (pages != null && !pages.isEmpty()) { - List imageUris = new ArrayList<>(); - for (GmsDocumentScanningResult.Page page : pages) { - imageUris.add(page.getImageUri().getPath()); - } - resultMap.put("images", imageUris); - } else { - resultMap.put("images", null); - } - - pendingResult.success(resultMap); - } -} diff --git a/packages/google_mlkit_document_scanner/android/src/main/java/com/google_mlkit_document_scanner/GoogleMlKitDocumentScannerPlugin.java b/packages/google_mlkit_document_scanner/android/src/main/java/com/google_mlkit_document_scanner/GoogleMlKitDocumentScannerPlugin.java deleted file mode 100644 index 03a04975..00000000 --- a/packages/google_mlkit_document_scanner/android/src/main/java/com/google_mlkit_document_scanner/GoogleMlKitDocumentScannerPlugin.java +++ /dev/null @@ -1,41 +0,0 @@ -package com.google_mlkit_document_scanner; - -import androidx.annotation.NonNull; - -import io.flutter.embedding.engine.plugins.activity.ActivityAware; -import io.flutter.embedding.engine.plugins.activity.ActivityPluginBinding; -import io.flutter.embedding.engine.plugins.FlutterPlugin; -import io.flutter.plugin.common.MethodChannel; - -public class GoogleMlKitDocumentScannerPlugin implements FlutterPlugin, ActivityAware { - private static final String channelName = "google_mlkit_document_scanner"; - private MethodChannel channel; - - @Override - public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) { - channel = new MethodChannel(flutterPluginBinding.getBinaryMessenger(), channelName); - } - - @Override - public void onDetachedFromEngine(@NonNull FlutterPluginBinding binding) { - channel.setMethodCallHandler(null); - } - - @Override - public void onAttachedToActivity(@NonNull ActivityPluginBinding binding) { - channel.setMethodCallHandler(new DocumentScanner(binding)); - } - - @Override - public void onReattachedToActivityForConfigChanges(@NonNull ActivityPluginBinding binding) { - channel.setMethodCallHandler(new DocumentScanner(binding)); - } - - @Override - public void onDetachedFromActivityForConfigChanges() { - } - - @Override - public void onDetachedFromActivity() { - } -} diff --git a/packages/google_mlkit_document_scanner/android/src/main/kotlin/com/google_mlkit_document_scanner/DocumentScanner.kt b/packages/google_mlkit_document_scanner/android/src/main/kotlin/com/google_mlkit_document_scanner/DocumentScanner.kt new file mode 100644 index 00000000..f4cc6c8e --- /dev/null +++ b/packages/google_mlkit_document_scanner/android/src/main/kotlin/com/google_mlkit_document_scanner/DocumentScanner.kt @@ -0,0 +1,179 @@ +package com.google_mlkit_document_scanner + +import android.app.Activity +import android.content.Intent +import android.content.IntentSender +import com.google.mlkit.vision.documentscanner.GmsDocumentScanner +import com.google.mlkit.vision.documentscanner.GmsDocumentScannerOptions +import com.google.mlkit.vision.documentscanner.GmsDocumentScanning +import com.google.mlkit.vision.documentscanner.GmsDocumentScanningResult +import io.flutter.embedding.engine.plugins.activity.ActivityPluginBinding +import io.flutter.plugin.common.MethodCall +import io.flutter.plugin.common.MethodChannel +import io.flutter.plugin.common.PluginRegistry + +class DocumentScanner( + private val binding: ActivityPluginBinding, +) : MethodChannel.MethodCallHandler, + PluginRegistry.ActivityResultListener { + private val instances = HashMap() + private var pendingResult: MethodChannel.Result? = null + + companion object { + private const val START = "vision#startDocumentScanner" + private const val CLOSE = "vision#closeDocumentScanner" + private const val TAG = "DocumentScanner" + private const val START_DOCUMENT_ACTIVITY = 0x362738 + } + + init { + binding.addActivityResultListener(this) + } + + override fun onMethodCall( + call: MethodCall, + result: MethodChannel.Result, + ) { + when (call.method) { + START -> { + handleScanner(call, result) + } + + CLOSE -> { + closeScanner(call) + result.success(null) + } + + else -> { + result.notImplemented() + } + } + } + + private fun handleScanner( + call: MethodCall, + result: MethodChannel.Result, + ) { + val id = call.argument("id") + var scanner = instances[id] + pendingResult = result + + if (scanner == null) { + val options = + call.argument>("options") ?: run { + result.error(TAG, "Invalid options", null) + return + } + val scannerOptions = parseOptions(options) + scanner = GmsDocumentScanning.getClient(scannerOptions) + instances[id!!] = scanner + } + + val activity = binding.activity + scanner + .getStartScanIntent(activity) + .addOnSuccessListener { intentSender -> + try { + activity.startIntentSenderForResult(intentSender, START_DOCUMENT_ACTIVITY, null, 0, 0, 0) + } catch (e: IntentSender.SendIntentException) { + result.error(TAG, "Failed to start document scanner", null) + } + }.addOnFailureListener { + result.error(TAG, "Failed to start document scanner", null) + } + } + + private fun parseOptions(options: Map): GmsDocumentScannerOptions { + val isGalleryImportAllowed = options["isGalleryImport"] as Boolean + val pageLimit = options["pageLimit"] as Int + val formatStrings = options["formats"] as List<*> + + val formatConstants = + formatStrings.map { format -> + when (format) { + "pdf" -> GmsDocumentScannerOptions.RESULT_FORMAT_PDF + "jpeg" -> GmsDocumentScannerOptions.RESULT_FORMAT_JPEG + else -> throw IllegalArgumentException("Not a format: $format") + } + } + + val mode = + when (options["mode"] as String) { + "base" -> GmsDocumentScannerOptions.SCANNER_MODE_BASE + "filter" -> GmsDocumentScannerOptions.SCANNER_MODE_BASE_WITH_FILTER + "full" -> GmsDocumentScannerOptions.SCANNER_MODE_FULL + else -> throw IllegalArgumentException("Not a mode: ${options["mode"]}") + } + + val builder = + GmsDocumentScannerOptions + .Builder() + .setGalleryImportAllowed(isGalleryImportAllowed) + .setPageLimit(pageLimit) + .setScannerMode(mode) + + if (formatConstants.isNotEmpty()) { + if (formatConstants.size > 1) { + builder.setResultFormats(formatConstants[0], formatConstants[1]) + } else { + builder.setResultFormats(formatConstants[0]) + } + } + + return builder.build() + } + + private fun closeScanner(call: MethodCall) { + val id = call.argument("id") ?: return + instances.remove(id) + } + + override fun onActivityResult( + requestCode: Int, + resultCode: Int, + intent: Intent?, + ): Boolean { + if (requestCode == START_DOCUMENT_ACTIVITY) { + when (resultCode) { + Activity.RESULT_OK -> { + val result = GmsDocumentScanningResult.fromActivityResultIntent(intent) + result?.let { handleScanningResult(it) } + } + + Activity.RESULT_CANCELED -> { + pendingResult?.error(TAG, "Operation cancelled", null) + } + + else -> { + pendingResult?.error(TAG, "Unknown Error", null) + } + } + return true + } + return false + } + + private fun handleScanningResult(result: GmsDocumentScanningResult) { + val resultMap = HashMap() + + val pdf = result.pdf + if (pdf != null) { + resultMap["pdf"] = + hashMapOf( + "pageCount" to pdf.pageCount, + "uri" to pdf.uri.path, + ) + } else { + resultMap["pdf"] = null + } + + val pages = result.pages + if (!pages.isNullOrEmpty()) { + resultMap["images"] = pages.map { it.imageUri.path } + } else { + resultMap["images"] = null + } + + pendingResult?.success(resultMap) + } +} diff --git a/packages/google_mlkit_document_scanner/android/src/main/kotlin/com/google_mlkit_document_scanner/GoogleMlKitDocumentScannerPlugin.kt b/packages/google_mlkit_document_scanner/android/src/main/kotlin/com/google_mlkit_document_scanner/GoogleMlKitDocumentScannerPlugin.kt new file mode 100644 index 00000000..38d0b106 --- /dev/null +++ b/packages/google_mlkit_document_scanner/android/src/main/kotlin/com/google_mlkit_document_scanner/GoogleMlKitDocumentScannerPlugin.kt @@ -0,0 +1,36 @@ +package com.google_mlkit_document_scanner + +import io.flutter.embedding.engine.plugins.FlutterPlugin +import io.flutter.embedding.engine.plugins.activity.ActivityAware +import io.flutter.embedding.engine.plugins.activity.ActivityPluginBinding +import io.flutter.plugin.common.MethodChannel + +class GoogleMlKitDocumentScannerPlugin : + FlutterPlugin, + ActivityAware { + private lateinit var channel: MethodChannel + + override fun onAttachedToEngine(flutterPluginBinding: FlutterPlugin.FlutterPluginBinding) { + channel = MethodChannel(flutterPluginBinding.binaryMessenger, CHANNEL_NAME) + } + + override fun onDetachedFromEngine(binding: FlutterPlugin.FlutterPluginBinding) { + channel.setMethodCallHandler(null) + } + + override fun onAttachedToActivity(binding: ActivityPluginBinding) { + channel.setMethodCallHandler(DocumentScanner(binding)) + } + + override fun onReattachedToActivityForConfigChanges(binding: ActivityPluginBinding) { + channel.setMethodCallHandler(DocumentScanner(binding)) + } + + override fun onDetachedFromActivityForConfigChanges() {} + + override fun onDetachedFromActivity() {} + + companion object { + private const val CHANNEL_NAME = "google_mlkit_document_scanner" + } +} diff --git a/packages/google_mlkit_entity_extraction/android/build.gradle b/packages/google_mlkit_entity_extraction/android/build.gradle index 258e97c5..f3a6162c 100644 --- a/packages/google_mlkit_entity_extraction/android/build.gradle +++ b/packages/google_mlkit_entity_extraction/android/build.gradle @@ -2,6 +2,7 @@ group = "com.google_mlkit_entity_extraction" version = "1.0" buildscript { + ext.kotlin_version = "2.2.20" repositories { google() mavenCentral() @@ -9,6 +10,7 @@ buildscript { dependencies { classpath("com.android.tools.build:gradle:8.13.0") + classpath('org.jetbrains.kotlin:kotlin-gradle-plugin:2.2.20') } } @@ -20,6 +22,7 @@ rootProject.allprojects { } apply plugin: "com.android.library" +apply plugin: "kotlin-android" android { namespace = "com.google_mlkit_entity_extraction" @@ -31,6 +34,14 @@ android { targetCompatibility = JavaVersion.VERSION_11 } + kotlinOptions { + jvmTarget = '11' + } + + sourceSets { + main.java.srcDirs += "src/main/kotlin" + } + defaultConfig { minSdk = 21 } diff --git a/packages/google_mlkit_entity_extraction/android/gradle/wrapper/gradle-wrapper.properties b/packages/google_mlkit_entity_extraction/android/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 00000000..128196a7 --- /dev/null +++ b/packages/google_mlkit_entity_extraction/android/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-9.0-milestone-1-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/packages/google_mlkit_entity_extraction/android/src/main/java/com/google_mlkit_entity_extraction/EntityExtractor.java b/packages/google_mlkit_entity_extraction/android/src/main/java/com/google_mlkit_entity_extraction/EntityExtractor.java deleted file mode 100644 index 53ab6a06..00000000 --- a/packages/google_mlkit_entity_extraction/android/src/main/java/com/google_mlkit_entity_extraction/EntityExtractor.java +++ /dev/null @@ -1,196 +0,0 @@ -package com.google_mlkit_entity_extraction; - -import androidx.annotation.NonNull; - -import com.google.mlkit.nl.entityextraction.DateTimeEntity; -import com.google.mlkit.nl.entityextraction.Entity; -import com.google.mlkit.nl.entityextraction.EntityAnnotation; -import com.google.mlkit.nl.entityextraction.EntityExtraction; -import com.google.mlkit.nl.entityextraction.EntityExtractionParams; -import com.google.mlkit.nl.entityextraction.EntityExtractionRemoteModel; -import com.google.mlkit.nl.entityextraction.EntityExtractorOptions; -import com.google.mlkit.nl.entityextraction.FlightNumberEntity; -import com.google.mlkit.nl.entityextraction.IbanEntity; -import com.google.mlkit.nl.entityextraction.IsbnEntity; -import com.google.mlkit.nl.entityextraction.MoneyEntity; -import com.google.mlkit.nl.entityextraction.PaymentCardEntity; -import com.google.mlkit.nl.entityextraction.TrackingNumberEntity; -import com.google_mlkit_commons.GenericModelManager; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Set; -import java.util.TimeZone; - -import io.flutter.plugin.common.MethodCall; -import io.flutter.plugin.common.MethodChannel; - -public class EntityExtractor implements MethodChannel.MethodCallHandler { - private static final String START = "nlp#startEntityExtractor"; - private static final String CLOSE = "nlp#closeEntityExtractor"; - private static final String MANAGE = "nlp#manageEntityExtractionModels"; - - private final Map instances = new HashMap<>(); - private final GenericModelManager genericModelManager = new GenericModelManager(); - - @Override - public void onMethodCall(@NonNull MethodCall call, @NonNull MethodChannel.Result result) { - String method = call.method; - switch (method) { - case START: - extractEntities(call, result); - break; - case CLOSE: - closeDetector(call); - result.success(null); - break; - case MANAGE: - manageModel(call, result); - break; - default: - result.notImplemented(); - break; - } - } - - private void extractEntities(MethodCall call, final MethodChannel.Result result) { - String text = call.argument("text"); - - String id = call.argument("id"); - com.google.mlkit.nl.entityextraction.EntityExtractor extractor = instances.get(id); - if (extractor == null) { - String language = call.argument("language"); - extractor = EntityExtraction.getClient( - new EntityExtractorOptions.Builder(language) - .build()); - instances.put(id, extractor); - } - final com.google.mlkit.nl.entityextraction.EntityExtractor entityExtractor = extractor; - - Map parameters = call.argument("parameters"); - Set filters = null; - if (parameters.get("filters") != null) { - filters = new HashSet<>((List) parameters.get("filters")); - } - - Locale locale = null; - if (parameters.get("locale") != null) { - locale = new Locale.Builder().setLanguage((String) parameters.get("locale")).build(); - } - - TimeZone timeZone = null; - if (parameters.get("timezone") != null) { - timeZone = TimeZone.getTimeZone((String) parameters.get("timezone")); - } - - Long referenceTime = null; - if (parameters.get("time") != null) { - referenceTime = (Long) parameters.get("time"); - } - - EntityExtractionParams params = new EntityExtractionParams.Builder(text) - .setEntityTypesFilter(filters) - .setPreferredLocale(locale) - .setReferenceTimeZone(timeZone) - .setReferenceTime(referenceTime) - .build(); - - entityExtractor - .downloadModelIfNeeded() - .addOnSuccessListener( - aVoid -> { - // Model downloading succeeded, you can call the extraction API here. - entityExtractor.annotate(params) - .addOnSuccessListener(entityAnnotations -> { - List> allAnnotations = new ArrayList<>(entityAnnotations.size()); - - for (EntityAnnotation entityAnnotation : entityAnnotations) { - Map annotation = new HashMap<>(); - List entities = entityAnnotation.getEntities(); - annotation.put("text", entityAnnotation.getAnnotatedText()); - annotation.put("start", entityAnnotation.getStart()); - annotation.put("end", entityAnnotation.getEnd()); - - List> allEntities = new ArrayList<>(); - for (Entity entity : entities) { - Map entityData = new HashMap<>(); - entityData.put("type", entity.getType()); - entityData.put("raw", entity.toString()); - switch (entity.getType()) { - case Entity.TYPE_ADDRESS: - case Entity.TYPE_URL: - case Entity.TYPE_PHONE: - case Entity.TYPE_EMAIL: - break; - case Entity.TYPE_DATE_TIME: - DateTimeEntity dateTimeEntity = entity.asDateTimeEntity(); - entityData.put("dateTimeGranularity", dateTimeEntity.getDateTimeGranularity() + 1); - entityData.put("timestamp", dateTimeEntity.getTimestampMillis()); - break; - case Entity.TYPE_FLIGHT_NUMBER: - FlightNumberEntity flightNumberEntity = entity.asFlightNumberEntity(); - entityData.put("code", flightNumberEntity.getAirlineCode()); - entityData.put("number", flightNumberEntity.getFlightNumber()); - break; - case Entity.TYPE_IBAN: - IbanEntity ibanEntity = entity.asIbanEntity(); - entityData.put("iban", ibanEntity.getIban()); - entityData.put("code", ibanEntity.getIbanCountryCode()); - break; - case Entity.TYPE_ISBN: - IsbnEntity isbnEntity = entity.asIsbnEntity(); - entityData.put("isbn", isbnEntity.getIsbn()); - break; - case Entity.TYPE_MONEY: - MoneyEntity moneyEntity = entity.asMoneyEntity(); - entityData.put("fraction", moneyEntity.getFractionalPart()); - entityData.put("integer", moneyEntity.getIntegerPart()); - entityData.put("unnormalized", moneyEntity.getUnnormalizedCurrency()); - break; - case Entity.TYPE_PAYMENT_CARD: - PaymentCardEntity paymentCardEntity = entity.asPaymentCardEntity(); - entityData.put("network", paymentCardEntity.getPaymentCardNetwork()); - entityData.put("number", paymentCardEntity.getPaymentCardNumber()); - break; - case Entity.TYPE_TRACKING_NUMBER: - TrackingNumberEntity trackingNumberEntity = entity.asTrackingNumberEntity(); - entityData.put("carrier", trackingNumberEntity.getParcelCarrier()); - entityData.put("number", trackingNumberEntity.getParcelTrackingNumber()); - break; - } - - allEntities.add(entityData); - } - annotation.put("entities", allEntities); - allAnnotations.add(annotation); - } - - result.success(allAnnotations); - }) - .addOnFailureListener(e -> result.error("BarcodeDetectorError", e.toString(), null)); - }) - .addOnFailureListener( - e -> { - // Model could not be downloaded or other internal error. - result.error("Error building extractor", "Model not downloaded", null); - }); - } - - private void closeDetector(MethodCall call) { - String id = call.argument("id"); - com.google.mlkit.nl.entityextraction.EntityExtractor entityExtractor = instances.get(id); - if (entityExtractor == null) return; - entityExtractor.close(); - instances.remove(id); - } - - private void manageModel(MethodCall call, final MethodChannel.Result result) { - EntityExtractionRemoteModel model = - new EntityExtractionRemoteModel.Builder(call.argument("model")).build(); - genericModelManager.manageModel(model, call, result); - } -} diff --git a/packages/google_mlkit_entity_extraction/android/src/main/java/com/google_mlkit_entity_extraction/GoogleMlKitEntityExtractionPlugin.java b/packages/google_mlkit_entity_extraction/android/src/main/java/com/google_mlkit_entity_extraction/GoogleMlKitEntityExtractionPlugin.java deleted file mode 100644 index 27d384c1..00000000 --- a/packages/google_mlkit_entity_extraction/android/src/main/java/com/google_mlkit_entity_extraction/GoogleMlKitEntityExtractionPlugin.java +++ /dev/null @@ -1,22 +0,0 @@ -package com.google_mlkit_entity_extraction; - -import androidx.annotation.NonNull; - -import io.flutter.embedding.engine.plugins.FlutterPlugin; -import io.flutter.plugin.common.MethodChannel; - -public class GoogleMlKitEntityExtractionPlugin implements FlutterPlugin { - private MethodChannel channel; - private static final String channelName = "google_mlkit_entity_extractor"; - - @Override - public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) { - channel = new MethodChannel(flutterPluginBinding.getBinaryMessenger(), channelName); - channel.setMethodCallHandler(new EntityExtractor()); - } - - @Override - public void onDetachedFromEngine(@NonNull FlutterPluginBinding binding) { - channel.setMethodCallHandler(null); - } -} diff --git a/packages/google_mlkit_entity_extraction/android/src/main/kotlin/com/google_mlkit_entity_extraction/EntityExtractor.kt b/packages/google_mlkit_entity_extraction/android/src/main/kotlin/com/google_mlkit_entity_extraction/EntityExtractor.kt new file mode 100644 index 00000000..fce12d2d --- /dev/null +++ b/packages/google_mlkit_entity_extraction/android/src/main/kotlin/com/google_mlkit_entity_extraction/EntityExtractor.kt @@ -0,0 +1,188 @@ +package com.google_mlkit_entity_extraction + +import com.google.mlkit.nl.entityextraction.Entity +import com.google.mlkit.nl.entityextraction.EntityExtraction +import com.google.mlkit.nl.entityextraction.EntityExtractionParams +import com.google.mlkit.nl.entityextraction.EntityExtractionRemoteModel +import com.google.mlkit.nl.entityextraction.EntityExtractorOptions +import com.google_mlkit_commons.GenericModelManager +import io.flutter.plugin.common.MethodCall +import io.flutter.plugin.common.MethodChannel +import java.util.Locale +import java.util.TimeZone + +class EntityExtractor : MethodChannel.MethodCallHandler { + companion object { + private const val START = "nlp#startEntityExtractor" + private const val CLOSE = "nlp#closeEntityExtractor" + private const val MANAGE = "nlp#manageEntityExtractionModels" + } + + private val instances = HashMap() + private val genericModelManager = GenericModelManager() + + override fun onMethodCall( + call: MethodCall, + result: MethodChannel.Result, + ) { + when (call.method) { + START -> { + extractEntities(call, result) + } + + CLOSE -> { + closeDetector(call) + result.success(null) + } + + MANAGE -> { + manageModel(call, result) + } + + else -> { + result.notImplemented() + } + } + } + + private fun extractEntities( + call: MethodCall, + result: MethodChannel.Result, + ) { + val text = call.argument("text") + val id = call.argument("id") + + var extractor = instances[id] + if (extractor == null) { + val language = call.argument("language")!! + extractor = + EntityExtraction.getClient( + EntityExtractorOptions.Builder(language).build(), + ) + instances[id!!] = extractor + } + + val entityExtractor = extractor + + val parameters = call.argument>("parameters")!! + val filters = (parameters["filters"] as? List)?.toHashSet() + + val locale = + (parameters["locale"] as? String)?.let { + Locale.Builder().setLanguage(it).build() + } + + val timeZone = + (parameters["timezone"] as? String)?.let { + TimeZone.getTimeZone(it) + } + + val referenceTime = parameters["time"] as? Long + + val params = + EntityExtractionParams + .Builder(text!!) + .setEntityTypesFilter(filters) + .setPreferredLocale(locale) + .setReferenceTimeZone(timeZone) + .setReferenceTime(referenceTime) + .build() + + entityExtractor + .downloadModelIfNeeded() + .addOnSuccessListener { + entityExtractor + .annotate(params) + .addOnSuccessListener { entityAnnotations -> + val allAnnotation = + entityAnnotations.map { entityAnnotation -> + val allEntities = + entityAnnotation.entities.map { entity -> + val entityData = HashMap() + entityData["type"] = entity.type + entityData["raw"] = entity.toString() + + when (entity.type) { + Entity.TYPE_ADDRESS, + Entity.TYPE_URL, + Entity.TYPE_PHONE, + Entity.TYPE_EMAIL, + -> { + Unit + } + + Entity.TYPE_DATE_TIME -> { + val dateTimeEntity = entity.asDateTimeEntity() + entityData["dateTimeGranularity"] = dateTimeEntity!!.dateTimeGranularity + 1 + entityData["timestamp"] = dateTimeEntity.timestampMillis + } + + Entity.TYPE_FLIGHT_NUMBER -> { + val flightNumberEntity = entity.asFlightNumberEntity() + entityData["code"] = flightNumberEntity!!.airlineCode + entityData["number"] = flightNumberEntity.flightNumber + } + + Entity.TYPE_IBAN -> { + val ibanEntity = entity.asIbanEntity() + entityData["iban"] = ibanEntity!!.iban + entityData["code"] = ibanEntity.ibanCountryCode + } + + Entity.TYPE_ISBN -> { + entityData["isbn"] = entity.asIsbnEntity()!!.isbn + } + + Entity.TYPE_MONEY -> { + val moneyEntity = entity.asMoneyEntity() + entityData["fraction"] = moneyEntity!!.fractionalPart + entityData["integer"] = moneyEntity.integerPart + entityData["unnormalized"] = moneyEntity.unnormalizedCurrency + } + + Entity.TYPE_PAYMENT_CARD -> { + val paymentCardEntity = entity.asPaymentCardEntity() + entityData["network"] = paymentCardEntity!!.paymentCardNetwork + entityData["number"] = paymentCardEntity.paymentCardNumber + } + + Entity.TYPE_TRACKING_NUMBER -> { + val trackingNumberEntity = entity.asTrackingNumberEntity() + entityData["carrier"] = trackingNumberEntity!!.parcelCarrier + entityData["number"] = trackingNumberEntity.parcelTrackingNumber + } + } + entityData + } + + hashMapOf( + "text" to entityAnnotation.annotatedText, + "start" to entityAnnotation.start, + "end" to entityAnnotation.end, + "entities" to allEntities, + ) + } + result.success(allAnnotation) + }.addOnFailureListener { e -> + result.error("BarcodeDetectorError", e.toString(), null) + } + }.addOnFailureListener { + result.error("Error building extractor", "Model not downloaded", null) + } + } + + private fun closeDetector(call: MethodCall) { + val id = call.argument("id") ?: return + val entityExtractor = instances[id] ?: return + entityExtractor.close() + instances.remove(id) + } + + private fun manageModel( + call: MethodCall, + result: MethodChannel.Result, + ) { + val model = EntityExtractionRemoteModel.Builder(call.argument("model")!!).build() + genericModelManager.manageModel(model, call, result) + } +} diff --git a/packages/google_mlkit_entity_extraction/android/src/main/kotlin/com/google_mlkit_entity_extraction/GoogleMlKitEntityExtractionPlugin.kt b/packages/google_mlkit_entity_extraction/android/src/main/kotlin/com/google_mlkit_entity_extraction/GoogleMlKitEntityExtractionPlugin.kt new file mode 100644 index 00000000..e0377e9a --- /dev/null +++ b/packages/google_mlkit_entity_extraction/android/src/main/kotlin/com/google_mlkit_entity_extraction/GoogleMlKitEntityExtractionPlugin.kt @@ -0,0 +1,21 @@ +package com.google_mlkit_entity_extraction + +import io.flutter.embedding.engine.plugins.FlutterPlugin +import io.flutter.plugin.common.MethodChannel + +class GoogleMlKitEntityExtractionPlugin : FlutterPlugin { + private lateinit var channel: MethodChannel + + companion object { + private const val CHANNEL_NAME = "google_mlkit_entity_extractor" + } + + override fun onAttachedToEngine(flutterPluginBinding: FlutterPlugin.FlutterPluginBinding) { + channel = MethodChannel(flutterPluginBinding.binaryMessenger, CHANNEL_NAME) + channel.setMethodCallHandler(EntityExtractor()) + } + + override fun onDetachedFromEngine(binding: FlutterPlugin.FlutterPluginBinding) { + channel.setMethodCallHandler(null) + } +} diff --git a/packages/google_mlkit_face_detection/android/build.gradle b/packages/google_mlkit_face_detection/android/build.gradle index 76bf75e6..58d14877 100644 --- a/packages/google_mlkit_face_detection/android/build.gradle +++ b/packages/google_mlkit_face_detection/android/build.gradle @@ -2,6 +2,7 @@ group = "com.google_mlkit_face_detection" version = "1.0" buildscript { + ext.kotlin_version = "2.2.20" repositories { google() mavenCentral() @@ -9,6 +10,7 @@ buildscript { dependencies { classpath("com.android.tools.build:gradle:8.13.0") + classpath('org.jetbrains.kotlin:kotlin-gradle-plugin:2.2.20') } } @@ -20,6 +22,7 @@ rootProject.allprojects { } apply plugin: "com.android.library" +apply plugin: "kotlin-android" android { namespace = "com.google_mlkit_face_detection" @@ -31,6 +34,14 @@ android { targetCompatibility = JavaVersion.VERSION_11 } + kotlinOptions { + jvmTarget = '11' + } + + sourceSets { + main.java.srcDirs += "src/main/kotlin" + } + defaultConfig { minSdk = 21 } diff --git a/packages/google_mlkit_face_detection/android/gradle/wrapper/gradle-wrapper.properties b/packages/google_mlkit_face_detection/android/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 00000000..128196a7 --- /dev/null +++ b/packages/google_mlkit_face_detection/android/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-9.0-milestone-1-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/packages/google_mlkit_face_detection/android/src/main/java/com/google_mlkit_face_detection/FaceDetector.java b/packages/google_mlkit_face_detection/android/src/main/java/com/google_mlkit_face_detection/FaceDetector.java deleted file mode 100644 index f4f17041..00000000 --- a/packages/google_mlkit_face_detection/android/src/main/java/com/google_mlkit_face_detection/FaceDetector.java +++ /dev/null @@ -1,235 +0,0 @@ -package com.google_mlkit_face_detection; - -import android.content.Context; -import android.graphics.PointF; -import android.graphics.Rect; - -import androidx.annotation.NonNull; - -import com.google.mlkit.vision.common.InputImage; -import com.google.mlkit.vision.face.Face; -import com.google.mlkit.vision.face.FaceContour; -import com.google.mlkit.vision.face.FaceDetection; -import com.google.mlkit.vision.face.FaceDetectorOptions; -import com.google.mlkit.vision.face.FaceLandmark; -import com.google_mlkit_commons.InputImageConverter; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import io.flutter.plugin.common.MethodCall; -import io.flutter.plugin.common.MethodChannel; - -class FaceDetector implements MethodChannel.MethodCallHandler { - private static final String START = "vision#startFaceDetector"; - private static final String CLOSE = "vision#closeFaceDetector"; - - private final Context context; - private final Map instances = new HashMap<>(); - - public FaceDetector(Context context) { - this.context = context; - } - - @Override - public void onMethodCall(@NonNull MethodCall call, @NonNull MethodChannel.Result result) { - String method = call.method; - switch (method) { - case START: - handleDetection(call, result); - break; - case CLOSE: - closeDetector(call); - result.success(null); - break; - default: - result.notImplemented(); - break; - } - } - - private void handleDetection(MethodCall call, final MethodChannel.Result result) { - Map imageData = (Map) call.argument("imageData"); - InputImage inputImage = InputImageConverter.getInputImageFromData(imageData, context, result); - if (inputImage == null) - return; - - String id = call.argument("id"); - com.google.mlkit.vision.face.FaceDetector detector = instances.get(id); - if (detector == null) { - Map options = call.argument("options"); - if (options == null) { - result.error("FaceDetectorError", "Invalid options", null); - return; - } - - FaceDetectorOptions detectorOptions = parseOptions(options); - detector = FaceDetection.getClient(detectorOptions); - instances.put(id, detector); - } - - detector.process(inputImage) - .addOnSuccessListener( - visionFaces -> { - List> faces = new ArrayList<>(visionFaces.size()); - for (Face face : visionFaces) { - Map faceData = new HashMap<>(); - - Map frame = new HashMap<>(); - Rect rect = face.getBoundingBox(); - frame.put("left", rect.left); - frame.put("top", rect.top); - frame.put("right", rect.right); - frame.put("bottom", rect.bottom); - faceData.put("rect", frame); - - faceData.put("headEulerAngleX", face.getHeadEulerAngleX()); - faceData.put("headEulerAngleY", face.getHeadEulerAngleY()); - faceData.put("headEulerAngleZ", face.getHeadEulerAngleZ()); - - if (face.getSmilingProbability() != null) { - faceData.put("smilingProbability", face.getSmilingProbability()); - } - - if (face.getLeftEyeOpenProbability() != null) { - faceData.put("leftEyeOpenProbability", face.getLeftEyeOpenProbability()); - } - - if (face.getRightEyeOpenProbability() != null) { - faceData.put("rightEyeOpenProbability", face.getRightEyeOpenProbability()); - } - - if (face.getTrackingId() != null) { - faceData.put("trackingId", face.getTrackingId()); - } - - faceData.put("landmarks", getLandmarkData(face)); - - faceData.put("contours", getContourData(face)); - - faces.add(faceData); - } - - result.success(faces); - }) - .addOnFailureListener( - e -> result.error("FaceDetectorError", e.toString(), null)); - } - - private FaceDetectorOptions parseOptions(Map options) { - int classification = (boolean) options.get("enableClassification") - ? FaceDetectorOptions.CLASSIFICATION_MODE_ALL - : FaceDetectorOptions.CLASSIFICATION_MODE_NONE; - - int landmark = (boolean) options.get("enableLandmarks") - ? FaceDetectorOptions.LANDMARK_MODE_ALL - : FaceDetectorOptions.LANDMARK_MODE_NONE; - - int contours = (boolean) options.get("enableContours") - ? FaceDetectorOptions.CONTOUR_MODE_ALL - : FaceDetectorOptions.CONTOUR_MODE_NONE; - - int mode; - switch ((String) options.get("mode")) { - case "accurate": - mode = FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE; - break; - case "fast": - mode = FaceDetectorOptions.PERFORMANCE_MODE_FAST; - break; - default: - throw new IllegalArgumentException("Not a mode:" + options.get("mode")); - } - - FaceDetectorOptions.Builder builder = new FaceDetectorOptions.Builder() - .setClassificationMode(classification) - .setLandmarkMode(landmark) - .setContourMode(contours) - .setMinFaceSize((float) ((double) options.get("minFaceSize"))) - .setPerformanceMode(mode); - - if ((boolean) options.get("enableTracking")) { - builder.enableTracking(); - } - - return builder.build(); - } - - private Map getLandmarkData(Face face) { - Map landmarks = new HashMap<>(); - - landmarks.put("bottomMouth", landmarkPosition(face, FaceLandmark.MOUTH_BOTTOM)); - landmarks.put("rightMouth", landmarkPosition(face, FaceLandmark.MOUTH_RIGHT)); - landmarks.put("leftMouth", landmarkPosition(face, FaceLandmark.MOUTH_LEFT)); - landmarks.put("rightEye", landmarkPosition(face, FaceLandmark.RIGHT_EYE)); - landmarks.put("leftEye", landmarkPosition(face, FaceLandmark.LEFT_EYE)); - landmarks.put("rightEar", landmarkPosition(face, FaceLandmark.RIGHT_EAR)); - landmarks.put("leftEar", landmarkPosition(face, FaceLandmark.LEFT_EAR)); - landmarks.put("rightCheek", landmarkPosition(face, FaceLandmark.RIGHT_CHEEK)); - landmarks.put("leftCheek", landmarkPosition(face, FaceLandmark.LEFT_CHEEK)); - landmarks.put("noseBase", landmarkPosition(face, FaceLandmark.NOSE_BASE)); - - return landmarks; - } - - private Map> getContourData(Face face) { - Map> contours = new HashMap<>(); - - contours.put("face", contourPosition(face, FaceContour.FACE)); - contours.put( - "leftEyebrowTop", contourPosition(face, FaceContour.LEFT_EYEBROW_TOP)); - contours.put( - "leftEyebrowBottom", contourPosition(face, FaceContour.LEFT_EYEBROW_BOTTOM)); - contours.put( - "rightEyebrowTop", contourPosition(face, FaceContour.RIGHT_EYEBROW_TOP)); - contours.put( - "rightEyebrowBottom", - contourPosition(face, FaceContour.RIGHT_EYEBROW_BOTTOM)); - contours.put("leftEye", contourPosition(face, FaceContour.LEFT_EYE)); - contours.put("rightEye", contourPosition(face, FaceContour.RIGHT_EYE)); - contours.put("upperLipTop", contourPosition(face, FaceContour.UPPER_LIP_TOP)); - contours.put( - "upperLipBottom", contourPosition(face, FaceContour.UPPER_LIP_BOTTOM)); - contours.put("lowerLipTop", contourPosition(face, FaceContour.LOWER_LIP_TOP)); - contours.put( - "lowerLipBottom", contourPosition(face, FaceContour.LOWER_LIP_BOTTOM)); - contours.put("noseBridge", contourPosition(face, FaceContour.NOSE_BRIDGE)); - contours.put("noseBottom", contourPosition(face, FaceContour.NOSE_BOTTOM)); - contours.put("leftCheek", contourPosition(face, FaceContour.LEFT_CHEEK)); - contours.put("rightCheek", contourPosition(face, FaceContour.RIGHT_CHEEK)); - - return contours; - } - - private double[] landmarkPosition(Face face, int landmarkInt) { - FaceLandmark landmark = face.getLandmark(landmarkInt); - if (landmark != null) { - return new double[] { landmark.getPosition().x, landmark.getPosition().y }; - } - return null; - } - - private List contourPosition(Face face, int contourInt) { - FaceContour contour = face.getContour(contourInt); - if (contour != null) { - List contourPoints = contour.getPoints(); - List result = new ArrayList<>(); - for (int i = 0; i < contourPoints.size(); i++) { - result.add(new double[] { contourPoints.get(i).x, contourPoints.get(i).y }); - } - return result; - } - return null; - } - - private void closeDetector(MethodCall call) { - String id = call.argument("id"); - com.google.mlkit.vision.face.FaceDetector detector = instances.get(id); - if (detector == null) - return; - detector.close(); - instances.remove(id); - } -} diff --git a/packages/google_mlkit_face_detection/android/src/main/java/com/google_mlkit_face_detection/GoogleMlKitFaceDetectionPlugin.java b/packages/google_mlkit_face_detection/android/src/main/java/com/google_mlkit_face_detection/GoogleMlKitFaceDetectionPlugin.java deleted file mode 100644 index 0913f993..00000000 --- a/packages/google_mlkit_face_detection/android/src/main/java/com/google_mlkit_face_detection/GoogleMlKitFaceDetectionPlugin.java +++ /dev/null @@ -1,22 +0,0 @@ -package com.google_mlkit_face_detection; - -import androidx.annotation.NonNull; - -import io.flutter.embedding.engine.plugins.FlutterPlugin; -import io.flutter.plugin.common.MethodChannel; - -public class GoogleMlKitFaceDetectionPlugin implements FlutterPlugin { - private MethodChannel channel; - private static final String channelName = "google_mlkit_face_detector"; - - @Override - public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) { - channel = new MethodChannel(flutterPluginBinding.getBinaryMessenger(), channelName); - channel.setMethodCallHandler(new FaceDetector(flutterPluginBinding.getApplicationContext())); - } - - @Override - public void onDetachedFromEngine(@NonNull FlutterPluginBinding binding) { - channel.setMethodCallHandler(null); - } -} diff --git a/packages/google_mlkit_face_detection/android/src/main/kotlin/com/google_mlkit_face_detection/FaceDetector.kt b/packages/google_mlkit_face_detection/android/src/main/kotlin/com/google_mlkit_face_detection/FaceDetector.kt new file mode 100644 index 00000000..3dba9480 --- /dev/null +++ b/packages/google_mlkit_face_detection/android/src/main/kotlin/com/google_mlkit_face_detection/FaceDetector.kt @@ -0,0 +1,190 @@ +package com.google_mlkit_face_detection + +import android.content.Context +import com.google.mlkit.vision.face.Face +import com.google.mlkit.vision.face.FaceContour +import com.google.mlkit.vision.face.FaceDetection +import com.google.mlkit.vision.face.FaceDetectorOptions +import com.google.mlkit.vision.face.FaceLandmark +import com.google_mlkit_commons.InputImageConverter +import io.flutter.plugin.common.MethodCall +import io.flutter.plugin.common.MethodChannel + +class FaceDetector( + private val context: Context, +) : MethodChannel.MethodCallHandler { + private val instances = mutableMapOf() + + companion object { + private const val START = "vision#startFaceDetector" + private const val CLOSE = "vision#closeFaceDetector" + } + + override fun onMethodCall( + call: MethodCall, + result: MethodChannel.Result, + ) { + when (call.method) { + START -> { + handleDetection(call, result) + } + + CLOSE -> { + closeDetector(call) + result.success(null) + } + + else -> { + result.notImplemented() + } + } + } + + private fun handleDetection( + call: MethodCall, + result: MethodChannel.Result, + ) { + val imageData = + call.argument>("imageData") ?: run { + result.error("FaceDetectorError", "imageData is null", null) + return + } + val inputImage = InputImageConverter.getInputImageFromData(imageData, context, result) ?: return + + val id = call.argument("id")!! + val detector = + instances.getOrPut(id) { + val options = + call.argument>("options") + ?: return result.error("FaceDetectorError", "Invalid options", null) + FaceDetection.getClient(parseOptions(options)) + } + + detector + .process(inputImage) + .addOnSuccessListener { visionFaces -> + val faces = + visionFaces.map { face -> + buildMap { + val rect = face.boundingBox + put( + "rect", + mapOf( + "left" to rect.left, + "top" to rect.top, + "right" to rect.right, + "bottom" to rect.bottom, + ), + ) + put("headEulerAngleX", face.headEulerAngleX) + put("headEulerAngleY", face.headEulerAngleY) + put("headEulerAngleZ", face.headEulerAngleZ) + face.smilingProbability?.let { put("smilingProbability", it) } + face.leftEyeOpenProbability?.let { put("leftEyeOpenProbability", it) } + face.rightEyeOpenProbability?.let { put("rightEyeOpenProbability", it) } + face.trackingId?.let { put("trackingId", it) } + put("landmarks", getLandmarkData(face)) + put("contours", getContourData(face)) + } + } + result.success(faces) + }.addOnFailureListener { e -> + result.error("FaceDetectorError", e.toString(), null) + } + } + + private fun parseOptions(options: Map): FaceDetectorOptions { + val classification = + if (options["enableClassification"] as Boolean) { + FaceDetectorOptions.CLASSIFICATION_MODE_ALL + } else { + FaceDetectorOptions.CLASSIFICATION_MODE_NONE + } + + val landmark = + if (options["enableLandmarks"] as Boolean) { + FaceDetectorOptions.LANDMARK_MODE_ALL + } else { + FaceDetectorOptions.LANDMARK_MODE_NONE + } + + val contours = + if (options["enableContours"] as Boolean) { + FaceDetectorOptions.CONTOUR_MODE_ALL + } else { + FaceDetectorOptions.CONTOUR_MODE_NONE + } + + val mode = + when (options["mode"] as String) { + "accurate" -> FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE + "fast" -> FaceDetectorOptions.PERFORMANCE_MODE_FAST + else -> throw IllegalArgumentException("Not a mode: ${options["mode"]}") + } + + return FaceDetectorOptions + .Builder() + .setClassificationMode(classification) + .setLandmarkMode(landmark) + .setContourMode(contours) + .setMinFaceSize((options["minFaceSize"] as Double).toFloat()) + .setPerformanceMode(mode) + .apply { + if (options["enableTracking"] as Boolean) enableTracking() + }.build() + } + + private fun getLandmarkData(face: Face): Map = + mapOf( + "bottomMouth" to landmarkPosition(face, FaceLandmark.MOUTH_BOTTOM), + "rightMouth" to landmarkPosition(face, FaceLandmark.MOUTH_RIGHT), + "leftMouth" to landmarkPosition(face, FaceLandmark.MOUTH_LEFT), + "rightEye" to landmarkPosition(face, FaceLandmark.RIGHT_EYE), + "leftEye" to landmarkPosition(face, FaceLandmark.LEFT_EYE), + "rightEar" to landmarkPosition(face, FaceLandmark.RIGHT_EAR), + "leftEar" to landmarkPosition(face, FaceLandmark.LEFT_EAR), + "rightCheek" to landmarkPosition(face, FaceLandmark.RIGHT_CHEEK), + "leftCheek" to landmarkPosition(face, FaceLandmark.LEFT_CHEEK), + "noseBase" to landmarkPosition(face, FaceLandmark.NOSE_BASE), + ) + + private fun getContourData(face: Face): Map?> = + mapOf( + "face" to contourPosition(face, FaceContour.FACE), + "leftEyebrowTop" to contourPosition(face, FaceContour.LEFT_EYEBROW_TOP), + "leftEyebrowBottom" to contourPosition(face, FaceContour.LEFT_EYEBROW_BOTTOM), + "rightEyebrowTop" to contourPosition(face, FaceContour.RIGHT_EYEBROW_TOP), + "rightEyebrowBottom" to contourPosition(face, FaceContour.RIGHT_EYEBROW_BOTTOM), + "leftEye" to contourPosition(face, FaceContour.LEFT_EYE), + "rightEye" to contourPosition(face, FaceContour.RIGHT_EYE), + "upperLipTop" to contourPosition(face, FaceContour.UPPER_LIP_TOP), + "upperLipBottom" to contourPosition(face, FaceContour.UPPER_LIP_BOTTOM), + "lowerLipTop" to contourPosition(face, FaceContour.LOWER_LIP_TOP), + "lowerLipBottom" to contourPosition(face, FaceContour.LOWER_LIP_BOTTOM), + "noseBridge" to contourPosition(face, FaceContour.NOSE_BRIDGE), + "noseBottom" to contourPosition(face, FaceContour.NOSE_BOTTOM), + "leftCheek" to contourPosition(face, FaceContour.LEFT_CHEEK), + "rightCheek" to contourPosition(face, FaceContour.RIGHT_CHEEK), + ) + + private fun landmarkPosition( + face: Face, + landmarkInt: Int, + ): DoubleArray? = + face.getLandmark(landmarkInt)?.position?.let { + doubleArrayOf(it.x.toDouble(), it.y.toDouble()) + } + + private fun contourPosition( + face: Face, + contourInt: Int, + ): List? = + face.getContour(contourInt)?.points?.map { + doubleArrayOf(it.x.toDouble(), it.y.toDouble()) + } + + private fun closeDetector(call: MethodCall) { + val id = call.argument("id") ?: return + instances.remove(id)?.close() + } +} diff --git a/packages/google_mlkit_face_detection/android/src/main/kotlin/com/google_mlkit_face_detection/GoogleMlKitFaceDetectionPlugin.kt b/packages/google_mlkit_face_detection/android/src/main/kotlin/com/google_mlkit_face_detection/GoogleMlKitFaceDetectionPlugin.kt new file mode 100644 index 00000000..29ac1dd4 --- /dev/null +++ b/packages/google_mlkit_face_detection/android/src/main/kotlin/com/google_mlkit_face_detection/GoogleMlKitFaceDetectionPlugin.kt @@ -0,0 +1,21 @@ +package com.google_mlkit_face_detection + +import io.flutter.embedding.engine.plugins.FlutterPlugin +import io.flutter.plugin.common.MethodChannel + +class GoogleMlKitFaceDetectionPlugin : FlutterPlugin { + private lateinit var channel: MethodChannel + + companion object { + private const val CHANNEL_NAME = "google_mlkit_face_detector" + } + + override fun onAttachedToEngine(flutterPluginBinding: FlutterPlugin.FlutterPluginBinding) { + channel = MethodChannel(flutterPluginBinding.binaryMessenger, CHANNEL_NAME) + channel.setMethodCallHandler(FaceDetector(flutterPluginBinding.applicationContext)) + } + + override fun onDetachedFromEngine(binding: FlutterPlugin.FlutterPluginBinding) { + channel.setMethodCallHandler(null) + } +} diff --git a/packages/google_mlkit_face_mesh_detection/android/build.gradle b/packages/google_mlkit_face_mesh_detection/android/build.gradle index e7016832..b447506b 100644 --- a/packages/google_mlkit_face_mesh_detection/android/build.gradle +++ b/packages/google_mlkit_face_mesh_detection/android/build.gradle @@ -2,6 +2,7 @@ group = "com.google_mlkit_face_mesh_detection" version = "1.0" buildscript { + ext.kotlin_version = "2.2.20" repositories { google() mavenCentral() @@ -9,6 +10,7 @@ buildscript { dependencies { classpath("com.android.tools.build:gradle:8.13.0") + classpath('org.jetbrains.kotlin:kotlin-gradle-plugin:2.2.20') } } @@ -20,6 +22,7 @@ rootProject.allprojects { } apply plugin: "com.android.library" +apply plugin: "kotlin-android" android { namespace = "com.google_mlkit_face_mesh_detection" @@ -31,6 +34,14 @@ android { targetCompatibility = JavaVersion.VERSION_11 } + kotlinOptions { + jvmTarget = '11' + } + + sourceSets { + main.java.srcDirs += "src/main/kotlin" + } + defaultConfig { minSdk = 21 } diff --git a/packages/google_mlkit_face_mesh_detection/android/gradle/wrapper/gradle-wrapper.properties b/packages/google_mlkit_face_mesh_detection/android/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 00000000..128196a7 --- /dev/null +++ b/packages/google_mlkit_face_mesh_detection/android/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-9.0-milestone-1-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/packages/google_mlkit_face_mesh_detection/android/src/main/java/com/google_mlkit_face_mesh_detection/FaceMeshDetector.java b/packages/google_mlkit_face_mesh_detection/android/src/main/java/com/google_mlkit_face_mesh_detection/FaceMeshDetector.java deleted file mode 100644 index 0d39560f..00000000 --- a/packages/google_mlkit_face_mesh_detection/android/src/main/java/com/google_mlkit_face_mesh_detection/FaceMeshDetector.java +++ /dev/null @@ -1,159 +0,0 @@ -package com.google_mlkit_face_mesh_detection; - -import android.content.Context; -import android.graphics.Rect; - -import androidx.annotation.NonNull; - -import com.google.mlkit.vision.common.InputImage; -import com.google.mlkit.vision.common.Triangle; -import com.google.mlkit.vision.facemesh.FaceMesh; -import com.google.mlkit.vision.facemesh.FaceMeshDetection; -import com.google.mlkit.vision.facemesh.FaceMeshDetectorOptions; -import com.google.mlkit.vision.facemesh.FaceMeshPoint; -import com.google_mlkit_commons.InputImageConverter; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import io.flutter.plugin.common.MethodCall; -import io.flutter.plugin.common.MethodChannel; - -class FaceMeshDetector implements MethodChannel.MethodCallHandler { - private static final String START = "vision#startFaceMeshDetector"; - private static final String CLOSE = "vision#closeFaceMeshDetector"; - - private final Context context; - private final Map instances = new HashMap<>(); - - public FaceMeshDetector(Context context) { - this.context = context; - } - - @Override - public void onMethodCall(@NonNull MethodCall call, @NonNull MethodChannel.Result result) { - String method = call.method; - switch (method) { - case START: - handleDetection(call, result); - break; - case CLOSE: - closeDetector(call); - result.success(null); - break; - default: - result.notImplemented(); - break; - } - } - - private void handleDetection(MethodCall call, final MethodChannel.Result result) { - Map imageData = (Map) call.argument("imageData"); - InputImage inputImage = InputImageConverter.getInputImageFromData(imageData, context, result); - if (inputImage == null) return; - - String id = call.argument("id"); - com.google.mlkit.vision.facemesh.FaceMeshDetector detector = instances.get(id); - if (detector == null) { - int option = call.argument("option"); - switch (option) { - case FaceMeshDetectorOptions.BOUNDING_BOX_ONLY: - detector = FaceMeshDetection.getClient( - new FaceMeshDetectorOptions.Builder() - .setUseCase(FaceMeshDetectorOptions.BOUNDING_BOX_ONLY) - .build() - ); - break; - - case FaceMeshDetectorOptions.FACE_MESH: - detector = FaceMeshDetection.getClient(); - - break; - - default: - result.error("FaceMeshDetectorError", "Invalid options", null); - return; - } - - instances.put(id, detector); - } - - detector.process(inputImage) - .addOnSuccessListener( - visionMeshes -> { - List> faceMeshes = new ArrayList<>(visionMeshes.size()); - for (FaceMesh mesh : visionMeshes) { - Map meshData = new HashMap<>(); - - Map frame = new HashMap<>(); - Rect rect = mesh.getBoundingBox(); - frame.put("left", rect.left); - frame.put("top", rect.top); - frame.put("right", rect.right); - frame.put("bottom", rect.bottom); - meshData.put("rect", frame); - - meshData.put("points", pointsToList(mesh.getAllPoints())); - - List>> triangles = new ArrayList<>(); - for (Triangle triangle : mesh.getAllTriangles()) { - triangles.add(pointsToList(triangle.getAllPoints())); - } - meshData.put("triangles", triangles); - - int[] types = { - FaceMesh.FACE_OVAL, - FaceMesh.LEFT_EYEBROW_TOP, - FaceMesh.LEFT_EYEBROW_BOTTOM, - FaceMesh.RIGHT_EYEBROW_TOP, - FaceMesh.RIGHT_EYEBROW_BOTTOM, - FaceMesh.LEFT_EYE, - FaceMesh.RIGHT_EYE, - FaceMesh.UPPER_LIP_TOP, - FaceMesh.UPPER_LIP_BOTTOM, - FaceMesh.LOWER_LIP_TOP, - FaceMesh.LOWER_LIP_BOTTOM, - FaceMesh.NOSE_BRIDGE - }; - Map>> contours = new HashMap<>(); - for (int type : types) { - contours.put(type - 1, pointsToList(mesh.getPoints(type))); - } - meshData.put("contours", contours); - - faceMeshes.add(meshData); - } - - result.success(faceMeshes); - }) - .addOnFailureListener( - e -> result.error("FaceMeshDetectorError", e.toString(), null)); - } - - private List> pointsToList(List points) { - List> list = new ArrayList<>(); - for (FaceMeshPoint point : points) { - list.add(pointToMap(point)); - } - return list; - } - - private Map pointToMap(FaceMeshPoint point) { - Map pointMap = new HashMap<>(); - pointMap.put("index", point.getIndex()); - pointMap.put("x", point.getPosition().getX()); - pointMap.put("y", point.getPosition().getY()); - pointMap.put("z", point.getPosition().getZ()); - return pointMap; - } - - private void closeDetector(MethodCall call) { - String id = call.argument("id"); - com.google.mlkit.vision.facemesh.FaceMeshDetector detector = instances.get(id); - if (detector == null) return; - detector.close(); - instances.remove(id); - } -} diff --git a/packages/google_mlkit_face_mesh_detection/android/src/main/java/com/google_mlkit_face_mesh_detection/GoogleMlKitFaceMeshDetectionPlugin.java b/packages/google_mlkit_face_mesh_detection/android/src/main/java/com/google_mlkit_face_mesh_detection/GoogleMlKitFaceMeshDetectionPlugin.java deleted file mode 100644 index 56bedce1..00000000 --- a/packages/google_mlkit_face_mesh_detection/android/src/main/java/com/google_mlkit_face_mesh_detection/GoogleMlKitFaceMeshDetectionPlugin.java +++ /dev/null @@ -1,22 +0,0 @@ -package com.google_mlkit_face_mesh_detection; - -import androidx.annotation.NonNull; - -import io.flutter.embedding.engine.plugins.FlutterPlugin; -import io.flutter.plugin.common.MethodChannel; - -public class GoogleMlKitFaceMeshDetectionPlugin implements FlutterPlugin { - private MethodChannel channel; - private static final String channelName = "google_mlkit_face_mesh_detector"; - - @Override - public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) { - channel = new MethodChannel(flutterPluginBinding.getBinaryMessenger(), channelName); - channel.setMethodCallHandler(new FaceMeshDetector(flutterPluginBinding.getApplicationContext())); - } - - @Override - public void onDetachedFromEngine(@NonNull FlutterPluginBinding binding) { - channel.setMethodCallHandler(null); - } -} diff --git a/packages/google_mlkit_face_mesh_detection/android/src/main/kotlin/com/google_mlkit_face_mesh_detection/FaceMeshDetector.kt b/packages/google_mlkit_face_mesh_detection/android/src/main/kotlin/com/google_mlkit_face_mesh_detection/FaceMeshDetector.kt new file mode 100644 index 00000000..9fba504b --- /dev/null +++ b/packages/google_mlkit_face_mesh_detection/android/src/main/kotlin/com/google_mlkit_face_mesh_detection/FaceMeshDetector.kt @@ -0,0 +1,149 @@ +package com.google_mlkit_face_mesh_detection + +import android.content.Context +import com.google.mlkit.vision.facemesh.FaceMesh +import com.google.mlkit.vision.facemesh.FaceMeshDetection +import com.google.mlkit.vision.facemesh.FaceMeshDetectorOptions +import com.google.mlkit.vision.facemesh.FaceMeshPoint +import com.google_mlkit_commons.InputImageConverter +import io.flutter.plugin.common.MethodCall +import io.flutter.plugin.common.MethodChannel + +class FaceMeshDetector( + private val context: Context, +) : MethodChannel.MethodCallHandler { + private val instances = mutableMapOf() + + companion object { + private const val START = "vision#startFaceMeshDetector" + private const val CLOSE = "vision#closeFaceMeshDetector" + } + + override fun onMethodCall( + call: MethodCall, + result: MethodChannel.Result, + ) { + when (call.method) { + START -> { + handleDetection(call, result) + } + + CLOSE -> { + closeDetector(call) + result.success(null) + } + + else -> { + result.notImplemented() + } + } + } + + private fun handleDetection( + call: MethodCall, + result: MethodChannel.Result, + ) { + val imageData = + call.argument>("imageData") ?: run { + result.error("FaceMeshDetectorError", "imageData is null", null) + return + } + + val inputImage = InputImageConverter.getInputImageFromData(imageData, context, result) ?: return + + val id = call.argument("id")!! + var detector = instances[id] + + if (detector == null) { + detector = + when (val option = call.argument("option")) { + FaceMeshDetectorOptions.BOUNDING_BOX_ONLY -> { + FaceMeshDetection.getClient( + FaceMeshDetectorOptions + .Builder() + .setUseCase(FaceMeshDetectorOptions.BOUNDING_BOX_ONLY) + .build(), + ) + } + + FaceMeshDetectorOptions.FACE_MESH -> { + FaceMeshDetection.getClient() + } + + else -> { + result.error("FaceMeshDetectorError", "Invalid options", null) + return + } + } + instances[id] = detector + } + + detector + .process(inputImage) + .addOnSuccessListener { visionMeshes -> + val faceMeshes = + visionMeshes.map { mesh -> + val rect = mesh.boundingBox + val frame = + mapOf( + "left" to rect.left, + "top" to rect.top, + "right" to rect.right, + "bottom" to rect.bottom, + ) + + val triangles = + mesh.allTriangles.map { triangle -> + pointsToList(triangle.allPoints) + } + + val types = + intArrayOf( + FaceMesh.FACE_OVAL, + FaceMesh.LEFT_EYEBROW_TOP, + FaceMesh.LEFT_EYEBROW_BOTTOM, + FaceMesh.RIGHT_EYEBROW_TOP, + FaceMesh.RIGHT_EYEBROW_BOTTOM, + FaceMesh.LEFT_EYE, + FaceMesh.RIGHT_EYE, + FaceMesh.UPPER_LIP_TOP, + FaceMesh.UPPER_LIP_BOTTOM, + FaceMesh.LOWER_LIP_TOP, + FaceMesh.LOWER_LIP_BOTTOM, + FaceMesh.NOSE_BRIDGE, + ) + + val contours = + types.associate { type -> + (type - 1) to pointsToList(mesh.getPoints(type)) + } + + mapOf( + "rect" to frame, + "points" to pointsToList(mesh.allPoints), + "triangles" to triangles, + "contours" to contours, + ) + } + result.success(faceMeshes) + }.addOnFailureListener { e -> + result.error("FaceMeshDetectorError", e.toString(), null) + } + } + + private fun pointsToList(points: List): List> = points.map { pointToMap(it) } + + private fun pointToMap(point: FaceMeshPoint): Map = + mapOf( + "index" to point.index, + "x" to point.position.x, + "y" to point.position.y, + "z" to point.position.z, + ) + + private fun closeDetector(call: MethodCall) { + val id = call.argument("id") ?: return + instances[id]?.close() + instances.remove(id) + } +} diff --git a/packages/google_mlkit_face_mesh_detection/android/src/main/kotlin/com/google_mlkit_face_mesh_detection/GoogleMlKitFaceMeshDetectionPlugin.kt b/packages/google_mlkit_face_mesh_detection/android/src/main/kotlin/com/google_mlkit_face_mesh_detection/GoogleMlKitFaceMeshDetectionPlugin.kt new file mode 100644 index 00000000..a6847888 --- /dev/null +++ b/packages/google_mlkit_face_mesh_detection/android/src/main/kotlin/com/google_mlkit_face_mesh_detection/GoogleMlKitFaceMeshDetectionPlugin.kt @@ -0,0 +1,21 @@ +package com.google_mlkit_face_mesh_detection + +import io.flutter.embedding.engine.plugins.FlutterPlugin +import io.flutter.plugin.common.MethodChannel + +class GoogleMlKitFaceMeshDetectionPlugin : FlutterPlugin { + private lateinit var channel: MethodChannel + + companion object { + private const val CHANNEL_NAME = "google_mlkit_face_mesh_detector" + } + + override fun onAttachedToEngine(flutterPluginBinding: FlutterPlugin.FlutterPluginBinding) { + channel = MethodChannel(flutterPluginBinding.binaryMessenger, CHANNEL_NAME) + channel.setMethodCallHandler(FaceMeshDetector(flutterPluginBinding.applicationContext)) + } + + override fun onDetachedFromEngine(binding: FlutterPlugin.FlutterPluginBinding) { + channel.setMethodCallHandler(null) + } +} diff --git a/packages/google_mlkit_genai_image_description/android/build.gradle b/packages/google_mlkit_genai_image_description/android/build.gradle index 29c6726b..9d5d7d8d 100644 --- a/packages/google_mlkit_genai_image_description/android/build.gradle +++ b/packages/google_mlkit_genai_image_description/android/build.gradle @@ -2,6 +2,7 @@ group = "com.google_mlkit_genai_image_description" version = "1.0" buildscript { + ext.kotlin_version = "2.2.20" repositories { google() mavenCentral() @@ -9,6 +10,7 @@ buildscript { dependencies { classpath("com.android.tools.build:gradle:8.13.0") + classpath('org.jetbrains.kotlin:kotlin-gradle-plugin:2.2.20') } } @@ -20,6 +22,7 @@ rootProject.allprojects { } apply plugin: "com.android.library" +apply plugin: "kotlin-android" android { namespace = "com.google_mlkit_genai_image_description" @@ -31,6 +34,14 @@ android { targetCompatibility = JavaVersion.VERSION_11 } + kotlinOptions { + jvmTarget = '11' + } + + sourceSets { + main.java.srcDirs += "src/main/kotlin" + } + defaultConfig { minSdk = 26 } diff --git a/packages/google_mlkit_genai_image_description/android/gradle/wrapper/gradle-wrapper.properties b/packages/google_mlkit_genai_image_description/android/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 00000000..128196a7 --- /dev/null +++ b/packages/google_mlkit_genai_image_description/android/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-9.0-milestone-1-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/packages/google_mlkit_genai_image_description/android/src/main/java/com/google_mlkit_genai_image_description/GoogleMlKitGenaiImageDescriptionPlugin.java b/packages/google_mlkit_genai_image_description/android/src/main/java/com/google_mlkit_genai_image_description/GoogleMlKitGenaiImageDescriptionPlugin.java deleted file mode 100644 index e7b3e337..00000000 --- a/packages/google_mlkit_genai_image_description/android/src/main/java/com/google_mlkit_genai_image_description/GoogleMlKitGenaiImageDescriptionPlugin.java +++ /dev/null @@ -1,22 +0,0 @@ -package com.google_mlkit_genai_image_description; - -import androidx.annotation.NonNull; - -import io.flutter.embedding.engine.plugins.FlutterPlugin; -import io.flutter.plugin.common.MethodChannel; - -public class GoogleMlKitGenaiImageDescriptionPlugin implements FlutterPlugin { - private MethodChannel channel; - private static final String channelName = "google_mlkit_genai_image_description"; - - @Override - public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) { - channel = new MethodChannel(flutterPluginBinding.getBinaryMessenger(), channelName); - channel.setMethodCallHandler(new ImageDescriber(flutterPluginBinding.getApplicationContext())); - } - - @Override - public void onDetachedFromEngine(@NonNull FlutterPluginBinding binding) { - channel.setMethodCallHandler(null); - } -} diff --git a/packages/google_mlkit_genai_image_description/android/src/main/java/com/google_mlkit_genai_image_description/ImageDescriber.java b/packages/google_mlkit_genai_image_description/android/src/main/java/com/google_mlkit_genai_image_description/ImageDescriber.java deleted file mode 100644 index 49494322..00000000 --- a/packages/google_mlkit_genai_image_description/android/src/main/java/com/google_mlkit_genai_image_description/ImageDescriber.java +++ /dev/null @@ -1,274 +0,0 @@ -package com.google_mlkit_genai_image_description; - -import android.content.Context; -import android.graphics.Bitmap; -import android.graphics.BitmapFactory; -import android.net.Uri; -import android.util.Log; - -import androidx.annotation.NonNull; - -import com.google.mlkit.genai.imagedescription.ImageDescription; -import com.google.mlkit.genai.imagedescription.ImageDescriptionRequest; -import com.google.mlkit.genai.imagedescription.ImageDescriberOptions; - -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.FutureCallback; - -import java.io.File; -import java.io.IOException; -import java.util.HashMap; -import java.util.Map; -import java.util.Objects; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.Executor; -import java.util.concurrent.Executors; - -import io.flutter.plugin.common.MethodCall; -import io.flutter.plugin.common.MethodChannel; - -public class ImageDescriber implements MethodChannel.MethodCallHandler { - private static final String CHECK_FEATURE_STATUS = "genai#checkFeatureStatus"; - private static final String DOWNLOAD_FEATURE = "genai#downloadFeature"; - private static final String RUN_INFERENCE = "genai#runInference"; - private static final String RUN_INFERENCE_STREAMING = "genai#runInferenceStreaming"; - private static final String CLOSE = "genai#closeImageDescriber"; - - private final Context context; - private final Map instances = new HashMap<>(); - private final Executor executor = Executors.newSingleThreadExecutor(); - - public ImageDescriber(Context context) { - this.context = context; - } - - @Override - public void onMethodCall(@NonNull MethodCall call, @NonNull MethodChannel.Result result) { - String method = call.method; - switch (method) { - case CHECK_FEATURE_STATUS: - checkFeatureStatus(call, result); - break; - case DOWNLOAD_FEATURE: - downloadFeature(call, result); - break; - case RUN_INFERENCE: - runInference(call, result); - break; - case RUN_INFERENCE_STREAMING: - runInferenceStreaming(call, result); - break; - case CLOSE: - closeImageDescriber(call); - result.success(null); - break; - default: - result.notImplemented(); - break; - } - } - - private com.google.mlkit.genai.imagedescription.ImageDescriber initialize(MethodCall call) { - ImageDescriberOptions options = ImageDescriberOptions.builder(context).build(); - return ImageDescription.getClient(options); - } - - private void checkFeatureStatus(MethodCall call, MethodChannel.Result result) { - String id = call.argument("id"); - com.google.mlkit.genai.imagedescription.ImageDescriber imageDescriber = instances.get(id); - if (imageDescriber == null) { - imageDescriber = initialize(call); - instances.put(id, imageDescriber); - } - - ListenableFuture future = imageDescriber.checkFeatureStatus(); - Futures.addCallback(future, new FutureCallback() { - @Override - public void onSuccess(Integer status) { - int statusValue; - if (status == com.google.mlkit.genai.common.FeatureStatus.UNAVAILABLE) { - statusValue = 0; - } else if (status == com.google.mlkit.genai.common.FeatureStatus.DOWNLOADABLE) { - statusValue = 1; - } else if (status == com.google.mlkit.genai.common.FeatureStatus.DOWNLOADING) { - statusValue = 2; - } else if (status == com.google.mlkit.genai.common.FeatureStatus.AVAILABLE) { - statusValue = 3; - } else { - statusValue = 0; - } - result.success(statusValue); - } - - @Override - public void onFailure(Throwable e) { - result.error("ImageDescriberError", e.toString(), null); - } - }, executor); - } - - private void downloadFeature(MethodCall call, MethodChannel.Result result) { - String id = call.argument("id"); - com.google.mlkit.genai.imagedescription.ImageDescriber imageDescriber = instances.get(id); - if (imageDescriber == null) { - imageDescriber = initialize(call); - instances.put(id, imageDescriber); - } - - imageDescriber.downloadFeature(new com.google.mlkit.genai.common.DownloadCallback() { - @Override - public void onDownloadStarted(long bytesToDownload) { - // Handle download started - } - - @Override - public void onDownloadFailed(com.google.mlkit.genai.common.GenAiException e) { - result.error("DownloadError", e.toString(), null); - } - - @Override - public void onDownloadProgress(long totalBytesDownloaded) { - // Handle download progress - } - - @Override - public void onDownloadCompleted() { - result.success(null); - } - }); - } - - private Bitmap getBitmapFromData(Map imageData, MethodChannel.Result result) { - String model = (String) imageData.get("type"); - if (model != null && model.equals("bitmap")) { - try { - byte[] bitmapData = (byte[]) imageData.get("bitmapData"); - if (bitmapData == null) { - result.error("ImageDescriberError", "Bitmap data is null", null); - return null; - } - - try { - Map metadataMap = (Map) imageData.get("metadata"); - if (metadataMap != null) { - int width = Double.valueOf(Objects.requireNonNull(metadataMap.get("width")).toString()).intValue(); - int height = Double.valueOf(Objects.requireNonNull(metadataMap.get("height")).toString()).intValue(); - - Bitmap bitmap = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888); - java.nio.IntBuffer intBuffer = java.nio.IntBuffer.allocate(bitmapData.length / 4); - - for (int i = 0; i < bitmapData.length; i += 4) { - int r = bitmapData[i] & 0xFF; - int g = bitmapData[i + 1] & 0xFF; - int b = bitmapData[i + 2] & 0xFF; - int a = bitmapData[i + 3] & 0xFF; - intBuffer.put((a << 24) | (r << 16) | (g << 8) | b); - } - intBuffer.rewind(); - bitmap.copyPixelsFromBuffer(intBuffer); - return bitmap; - } - } catch (Exception e) { - Log.e("ImageError", "Error creating bitmap from raw data", e); - } - - Bitmap bitmap = BitmapFactory.decodeByteArray(bitmapData, 0, bitmapData.length); - if (bitmap == null) { - result.error("ImageDescriberError", "Failed to decode bitmap from the provided data", null); - return null; - } - return bitmap; - } catch (Exception e) { - Log.e("ImageError", "Getting Bitmap failed", e); - result.error("ImageDescriberError", e.toString(), null); - return null; - } - } else if (model != null && model.equals("file")) { - try { - String path = (String) imageData.get("path"); - if (path == null) { - result.error("ImageDescriberError", "Image file path is null", null); - return null; - } - File imageFile = new File(path); - if (!imageFile.exists()) { - result.error("ImageDescriberError", "Image file does not exist", null); - return null; - } - Bitmap bitmap = BitmapFactory.decodeFile(imageFile.getAbsolutePath()); - if (bitmap == null) { - result.error("ImageDescriberError", "Failed to decode bitmap from file", null); - return null; - } - return bitmap; - } catch (Exception e) { - Log.e("ImageError", "Getting Bitmap from file failed", e); - result.error("ImageDescriberError", e.toString(), null); - return null; - } - } else if (model != null && model.equals("bytes")) { - try { - byte[] bytes = (byte[]) imageData.get("bytes"); - if (bytes == null) { - result.error("ImageDescriberError", "Image bytes are null", null); - return null; - } - Bitmap bitmap = BitmapFactory.decodeByteArray(bytes, 0, bytes.length); - if (bitmap == null) { - result.error("ImageDescriberError", "Failed to decode bitmap from bytes", null); - return null; - } - return bitmap; - } catch (Exception e) { - Log.e("ImageError", "Getting Bitmap from bytes failed", e); - result.error("ImageDescriberError", e.toString(), null); - return null; - } - } - result.error("ImageDescriberError", "Invalid image type", null); - return null; - } - - private void runInference(MethodCall call, MethodChannel.Result result) { - String id = call.argument("id"); - Map imageData = call.argument("imageData"); - com.google.mlkit.genai.imagedescription.ImageDescriber imageDescriber = instances.get(id); - if (imageDescriber == null) { - imageDescriber = initialize(call); - instances.put(id, imageDescriber); - } - - Bitmap bitmap = getBitmapFromData(imageData, result); - if (bitmap == null) return; - - ImageDescriptionRequest request = ImageDescriptionRequest.builder(bitmap).build(); - ListenableFuture future = imageDescriber.runInference(request); - Futures.addCallback(future, new FutureCallback() { - @Override - public void onSuccess(com.google.mlkit.genai.imagedescription.ImageDescriptionResult imageDescriptionResult) { - Map response = new HashMap<>(); - response.put("description", imageDescriptionResult.getDescription()); - result.success(response); - } - - @Override - public void onFailure(Throwable e) { - result.error("InferenceError", e.toString(), null); - } - }, executor); - } - - private void runInferenceStreaming(MethodCall call, MethodChannel.Result result) { - // Streaming implementation would require EventChannel - result.notImplemented(); - } - - private void closeImageDescriber(MethodCall call) { - String id = call.argument("id"); - com.google.mlkit.genai.imagedescription.ImageDescriber imageDescriber = instances.get(id); - if (imageDescriber == null) return; - imageDescriber.close(); - instances.remove(id); - } -} diff --git a/packages/google_mlkit_genai_image_description/android/src/main/kotlin/com/google_mlkit_genai_image_description/GoogleMlKitGenaiImageDescriptionPlugin.kt b/packages/google_mlkit_genai_image_description/android/src/main/kotlin/com/google_mlkit_genai_image_description/GoogleMlKitGenaiImageDescriptionPlugin.kt new file mode 100644 index 00000000..b43ef2d7 --- /dev/null +++ b/packages/google_mlkit_genai_image_description/android/src/main/kotlin/com/google_mlkit_genai_image_description/GoogleMlKitGenaiImageDescriptionPlugin.kt @@ -0,0 +1,21 @@ +package com.google_mlkit_genai_image_description + +import io.flutter.embedding.engine.plugins.FlutterPlugin +import io.flutter.plugin.common.MethodChannel + +class GoogleMlKitGenaiImageDescriptionPlugin : FlutterPlugin { + private lateinit var channel: MethodChannel + + companion object { + private const val CHANNEL_NAME = "google_mlkit_genai_image_description" + } + + override fun onAttachedToEngine(flutterPluginBinding: FlutterPlugin.FlutterPluginBinding) { + channel = MethodChannel(flutterPluginBinding.binaryMessenger, CHANNEL_NAME) + channel.setMethodCallHandler(ImageDescriber(flutterPluginBinding.applicationContext)) + } + + override fun onDetachedFromEngine(binding: FlutterPlugin.FlutterPluginBinding) { + channel.setMethodCallHandler(null) + } +} diff --git a/packages/google_mlkit_genai_image_description/android/src/main/kotlin/com/google_mlkit_genai_image_description/ImageDescriber.kt b/packages/google_mlkit_genai_image_description/android/src/main/kotlin/com/google_mlkit_genai_image_description/ImageDescriber.kt new file mode 100644 index 00000000..0219e3a3 --- /dev/null +++ b/packages/google_mlkit_genai_image_description/android/src/main/kotlin/com/google_mlkit_genai_image_description/ImageDescriber.kt @@ -0,0 +1,247 @@ +package com.google_mlkit_genai_image_description + +import android.content.Context +import android.graphics.Bitmap +import android.graphics.BitmapFactory +import android.util.Log +import com.google.common.util.concurrent.FutureCallback +import com.google.common.util.concurrent.Futures +import com.google.mlkit.genai.common.DownloadCallback +import com.google.mlkit.genai.common.FeatureStatus +import com.google.mlkit.genai.common.GenAiException +import com.google.mlkit.genai.imagedescription.ImageDescriberOptions +import com.google.mlkit.genai.imagedescription.ImageDescription +import com.google.mlkit.genai.imagedescription.ImageDescriptionRequest +import io.flutter.plugin.common.MethodCall +import io.flutter.plugin.common.MethodChannel +import java.io.File +import java.nio.IntBuffer +import java.util.concurrent.Executor +import java.util.concurrent.Executors + +class ImageDescriber( + private val context: Context, +) : MethodChannel.MethodCallHandler { + private val instances = mutableMapOf() + + companion object { + private const val CHECK_FEATURE_STATUS = "genai#checkFeatureStatus" + private const val CLOSE = "genai#closeImageDescriber" + private const val DOWNLOAD_FEATURE = "genai#downloadFeature" + private const val RUN_INFERENCE = "genai#runInference" + private const val RUN_INFERENCE_STREAMING = "genai#runInferenceStreaming" + } + + private val executor: Executor = Executors.newSingleThreadExecutor() + + override fun onMethodCall( + call: MethodCall, + result: MethodChannel.Result, + ) { + when (call.method) { + CHECK_FEATURE_STATUS -> { + checkFeatureStatus(call, result) + } + + DOWNLOAD_FEATURE -> { + downloadFeature(call, result) + } + + RUN_INFERENCE -> { + runInference(call, result) + } + + RUN_INFERENCE_STREAMING -> { + result.notImplemented() + } + + CLOSE -> { + closeImageDescriber(call) + result.success(null) + } + + else -> { + result.notImplemented() + } + } + } + + private fun initialize(): com.google.mlkit.genai.imagedescription.ImageDescriber { + val options = ImageDescriberOptions.builder(context).build() + return ImageDescription.getClient(options) + } + + private fun getOrCreateInstance(call: MethodCall): Pair { + val id = call.argument("id")!! + val describer = instances.getOrPut(id) { initialize() } + return id to describer + } + + private fun checkFeatureStatus( + call: MethodCall, + result: MethodChannel.Result, + ) { + val (_, imageDescriber) = getOrCreateInstance(call) + + Futures.addCallback( + imageDescriber.checkFeatureStatus(), + object : FutureCallback { + override fun onSuccess(status: Int?) { + val statusValue = + when (status) { + FeatureStatus.UNAVAILABLE -> 0 + FeatureStatus.DOWNLOADABLE -> 1 + FeatureStatus.DOWNLOADING -> 2 + FeatureStatus.AVAILABLE -> 3 + else -> 0 + } + result.success(statusValue) + } + + override fun onFailure(e: Throwable) { + result.error("ImageDescriberError", e.toString(), null) + } + }, + executor, + ) + } + + private fun downloadFeature( + call: MethodCall, + result: MethodChannel.Result, + ) { + val (_, imageDescriber) = getOrCreateInstance(call) + + imageDescriber.downloadFeature( + object : DownloadCallback { + override fun onDownloadStarted(bytesToDownload: Long) {} + + override fun onDownloadFailed(e: GenAiException) { + result.error("DownloadError", e.toString(), null) + } + + override fun onDownloadProgress(totalBytesDownloaded: Long) {} + + override fun onDownloadCompleted() { + result.success(null) + } + }, + ) + } + + private fun getBitmapFromData( + imageData: Map, + result: MethodChannel.Result, + ): Bitmap? { + return when (imageData["type"] as? String) { + "bitmap" -> { + val bitmapData = + imageData["bitmapData"] as? ByteArray + ?: return run { + result.error("ImageDescriberError", "Bitmap data is null", null) + null + } + + try { + val metadataMap = imageData["metadata"] as? Map<*, *> + if (metadataMap != null) { + val width = metadataMap["width"].toString().toDouble().toInt() + val height = metadataMap["height"].toString().toDouble().toInt() + val intBuffer = IntBuffer.allocate(bitmapData.size / 4) + + for (i in bitmapData.indices step 4) { + val r = bitmapData[i].toInt() and 0xFF + val g = bitmapData[i + 1].toInt() and 0xFF + val b = bitmapData[i + 2].toInt() and 0xFF + val a = bitmapData[i + 3].toInt() and 0xFF + intBuffer.put((a shl 24) or (r shl 16) or (g shl 8) or b) + } + intBuffer.rewind() + val createdBitmap = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888) + createdBitmap.copyPixelsFromBuffer(intBuffer) + return createdBitmap + } + } catch (e: Exception) { + Log.e("ImageError", "Error creating bitmap from raw data", e) + } + + BitmapFactory.decodeByteArray(bitmapData, 0, bitmapData.size) + ?: return run { + result.error("ImageDescriberError", "Failed to decode bitmap from the provided data", null) + null + } + } + + "file" -> { + val path = + imageData["path"] as? String + ?: return run { + result.error("ImageDescriberError", "Image file path is null", null) + null + } + val imageFile = File(path) + if (!imageFile.exists()) { + result.error("ImageDescriberError", "Image file does not exist", null) + return null + } + BitmapFactory.decodeFile(imageFile.absolutePath) + ?: return run { + result.error("ImageDescriberError", "Failed to decode bitmap from file", null) + null + } + } + + "bytes" -> { + val bytes = + imageData["bytes"] as? ByteArray + ?: return run { + result.error("ImageDescriberError", "Image bytes are null", null) + null + } + BitmapFactory.decodeByteArray(bytes, 0, bytes.size) + ?: return run { + result.error("ImageDescriberError", "Failed to decode bitmap from bytes", null) + null + } + } + + else -> { + run { + result.error("ImageDescriberError", "Invalid image type", null) + null + } + } + } + } + + private fun runInference( + call: MethodCall, + result: MethodChannel.Result, + ) { + val (_, imageDescriber) = getOrCreateInstance(call) + val imageData = call.argument>("imageData") ?: return + + val bitmap = getBitmapFromData(imageData, result) ?: return + val request = ImageDescriptionRequest.builder(bitmap).build() + + Futures.addCallback( + imageDescriber.runInference(request), + object : FutureCallback { + override fun onSuccess(imageDescriptionResult: com.google.mlkit.genai.imagedescription.ImageDescriptionResult?) { + result.success(mapOf("description" to imageDescriptionResult?.description)) + } + + override fun onFailure(e: Throwable) { + result.error("InferenceError", e.toString(), null) + } + }, + executor, + ) + } + + private fun closeImageDescriber(call: MethodCall) { + val id = call.argument("id") ?: return + instances[id]?.close() + instances.remove(id) + } +} diff --git a/packages/google_mlkit_genai_prompt/android/build.gradle b/packages/google_mlkit_genai_prompt/android/build.gradle index 34b7e690..84a0b5a3 100644 --- a/packages/google_mlkit_genai_prompt/android/build.gradle +++ b/packages/google_mlkit_genai_prompt/android/build.gradle @@ -2,6 +2,7 @@ group = "com.google_mlkit_genai_prompt" version = "1.0" buildscript { + ext.kotlin_version = "2.2.20" repositories { google() mavenCentral() @@ -9,6 +10,7 @@ buildscript { dependencies { classpath("com.android.tools.build:gradle:8.13.0") + classpath('org.jetbrains.kotlin:kotlin-gradle-plugin:2.2.20') } } @@ -20,6 +22,7 @@ rootProject.allprojects { } apply plugin: "com.android.library" +apply plugin: "kotlin-android" android { namespace = "com.google_mlkit_genai_prompt" @@ -31,6 +34,15 @@ android { targetCompatibility = JavaVersion.VERSION_11 } + kotlinOptions { + jvmTarget = '11' + } + + sourceSets { + main.java.srcDirs += "src/main/kotlin" + } + + defaultConfig { minSdk = 26 } diff --git a/packages/google_mlkit_genai_prompt/android/gradle/wrapper/gradle-wrapper.properties b/packages/google_mlkit_genai_prompt/android/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 00000000..128196a7 --- /dev/null +++ b/packages/google_mlkit_genai_prompt/android/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-9.0-milestone-1-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/packages/google_mlkit_genai_prompt/android/src/main/java/com/google_mlkit_genai_prompt/GoogleMlKitGenaiPromptPlugin.java b/packages/google_mlkit_genai_prompt/android/src/main/java/com/google_mlkit_genai_prompt/GoogleMlKitGenaiPromptPlugin.java deleted file mode 100644 index 1be8cfc3..00000000 --- a/packages/google_mlkit_genai_prompt/android/src/main/java/com/google_mlkit_genai_prompt/GoogleMlKitGenaiPromptPlugin.java +++ /dev/null @@ -1,22 +0,0 @@ -package com.google_mlkit_genai_prompt; - -import androidx.annotation.NonNull; - -import io.flutter.embedding.engine.plugins.FlutterPlugin; -import io.flutter.plugin.common.MethodChannel; - -public class GoogleMlKitGenaiPromptPlugin implements FlutterPlugin { - private MethodChannel channel; - private static final String channelName = "google_mlkit_genai_prompt"; - - @Override - public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) { - channel = new MethodChannel(flutterPluginBinding.getBinaryMessenger(), channelName); - channel.setMethodCallHandler(new Prompt(flutterPluginBinding.getApplicationContext())); - } - - @Override - public void onDetachedFromEngine(@NonNull FlutterPluginBinding binding) { - channel.setMethodCallHandler(null); - } -} diff --git a/packages/google_mlkit_genai_prompt/android/src/main/java/com/google_mlkit_genai_prompt/Prompt.java b/packages/google_mlkit_genai_prompt/android/src/main/java/com/google_mlkit_genai_prompt/Prompt.java deleted file mode 100644 index a4a21a47..00000000 --- a/packages/google_mlkit_genai_prompt/android/src/main/java/com/google_mlkit_genai_prompt/Prompt.java +++ /dev/null @@ -1,166 +0,0 @@ -package com.google_mlkit_genai_prompt; - -import android.content.Context; - -import androidx.annotation.NonNull; - -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.FutureCallback; - -import java.util.HashMap; -import java.util.Map; -import java.util.concurrent.Executor; -import java.util.concurrent.Executors; - -import io.flutter.plugin.common.MethodCall; -import io.flutter.plugin.common.MethodChannel; - -public class Prompt implements MethodChannel.MethodCallHandler { - private static final String CHECK_FEATURE_STATUS = "genai#checkFeatureStatus"; - private static final String DOWNLOAD_FEATURE = "genai#downloadFeature"; - private static final String RUN_INFERENCE = "genai#runInference"; - private static final String RUN_INFERENCE_STREAMING = "genai#runInferenceStreaming"; - private static final String CLOSE = "genai#closePrompt"; - - private final Context context; - private final Map instances = new HashMap<>(); - private final Executor executor = Executors.newSingleThreadExecutor(); - - public Prompt(Context context) { - this.context = context; - } - - @Override - public void onMethodCall(@NonNull MethodCall call, @NonNull MethodChannel.Result result) { - String method = call.method; - switch (method) { - case CHECK_FEATURE_STATUS: - checkFeatureStatus(call, result); - break; - case DOWNLOAD_FEATURE: - downloadFeature(call, result); - break; - case RUN_INFERENCE: - runInference(call, result); - break; - case RUN_INFERENCE_STREAMING: - runInferenceStreaming(call, result); - break; - case CLOSE: - closePrompt(call); - result.success(null); - break; - default: - result.notImplemented(); - break; - } - } - - private Object initialize(MethodCall call) { - // Prompt API initialization - structure may vary - // This is a placeholder implementation - try { - // Try to use Generation.INSTANCE.getClient() for Java - Object generationInstance = Class.forName("com.google.mlkit.genai.prompt.Generation").getField("INSTANCE").get(null); - Object generativeModel = generationInstance.getClass().getMethod("getClient").invoke(generationInstance); - Object generativeModelFutures = Class.forName("com.google.mlkit.genai.prompt.GenerativeModelFutures") - .getMethod("from", Class.forName("com.google.mlkit.genai.prompt.GenerativeModel")) - .invoke(null, generativeModel); - return generativeModelFutures; - } catch (Exception e) { - // If reflection fails, return a placeholder - return new Object(); - } - } - - private void checkFeatureStatus(MethodCall call, MethodChannel.Result result) { - String id = call.argument("id"); - Object generativeModel = instances.get(id); - if (generativeModel == null) { - generativeModel = initialize(call); - instances.put(id, generativeModel); - } - - try { - ListenableFuture future = (ListenableFuture) generativeModel.getClass() - .getMethod("checkStatus").invoke(generativeModel); - Futures.addCallback(future, new FutureCallback() { - @Override - public void onSuccess(Integer status) { - int statusValue; - if (status == com.google.mlkit.genai.common.FeatureStatus.UNAVAILABLE) { - statusValue = 0; - } else if (status == com.google.mlkit.genai.common.FeatureStatus.DOWNLOADABLE) { - statusValue = 1; - } else if (status == com.google.mlkit.genai.common.FeatureStatus.DOWNLOADING) { - statusValue = 2; - } else if (status == com.google.mlkit.genai.common.FeatureStatus.AVAILABLE) { - statusValue = 3; - } else { - statusValue = 0; - } - result.success(statusValue); - } - - @Override - public void onFailure(Throwable e) { - result.error("PromptError", e.toString(), null); - } - }, executor); - } catch (Exception e) { - result.error("PromptError", "Failed to check status: " + e.toString(), null); - } - } - - private void downloadFeature(MethodCall call, MethodChannel.Result result) { - String id = call.argument("id"); - Object generativeModel = instances.get(id); - if (generativeModel == null) { - generativeModel = initialize(call); - instances.put(id, generativeModel); - } - - try { - generativeModel.getClass().getMethod("download", com.google.mlkit.genai.common.DownloadCallback.class) - .invoke(generativeModel, new com.google.mlkit.genai.common.DownloadCallback() { - @Override - public void onDownloadStarted(long bytesToDownload) { - // Handle download started - } - - @Override - public void onDownloadFailed(com.google.mlkit.genai.common.GenAiException e) { - result.error("DownloadError", e.toString(), null); - } - - @Override - public void onDownloadProgress(long totalBytesDownloaded) { - // Handle download progress - } - - @Override - public void onDownloadCompleted() { - result.success(null); - } - }); - } catch (Exception e) { - result.error("PromptError", "Failed to download: " + e.toString(), null); - } - } - - private void runInference(MethodCall call, MethodChannel.Result result) { - // Prompt API inference - placeholder implementation - result.error("PromptError", "Prompt API inference not yet fully implemented", null); - } - - private void runInferenceStreaming(MethodCall call, MethodChannel.Result result) { - // Streaming implementation would require EventChannel - result.notImplemented(); - } - - private void closePrompt(MethodCall call) { - String id = call.argument("id"); - instances.remove(id); - } -} diff --git a/packages/google_mlkit_genai_prompt/android/src/main/kotlin/com/google_mlkit_genai_prompt/GoogleMlKitGenaiPromptPlugin.kt b/packages/google_mlkit_genai_prompt/android/src/main/kotlin/com/google_mlkit_genai_prompt/GoogleMlKitGenaiPromptPlugin.kt new file mode 100644 index 00000000..1ff5490a --- /dev/null +++ b/packages/google_mlkit_genai_prompt/android/src/main/kotlin/com/google_mlkit_genai_prompt/GoogleMlKitGenaiPromptPlugin.kt @@ -0,0 +1,21 @@ +package com.google_mlkit_genai_prompt + +import io.flutter.embedding.engine.plugins.FlutterPlugin +import io.flutter.plugin.common.MethodChannel + +class GoogleMlKitGenaiPromptPlugin : FlutterPlugin { + private lateinit var channel: MethodChannel + + companion object { + private const val CHANNEL_NAME = "google_mlkit_genai_prompt" + } + + override fun onAttachedToEngine(flutterPluginBinding: FlutterPlugin.FlutterPluginBinding) { + channel = MethodChannel(flutterPluginBinding.binaryMessenger, CHANNEL_NAME) + channel.setMethodCallHandler(Prompt(flutterPluginBinding.applicationContext)) + } + + override fun onDetachedFromEngine(binding: FlutterPlugin.FlutterPluginBinding) { + channel.setMethodCallHandler(null) + } +} diff --git a/packages/google_mlkit_genai_prompt/android/src/main/kotlin/com/google_mlkit_genai_prompt/Prompt.kt b/packages/google_mlkit_genai_prompt/android/src/main/kotlin/com/google_mlkit_genai_prompt/Prompt.kt new file mode 100644 index 00000000..d5e5b96a --- /dev/null +++ b/packages/google_mlkit_genai_prompt/android/src/main/kotlin/com/google_mlkit_genai_prompt/Prompt.kt @@ -0,0 +1,164 @@ +package com.google_mlkit_genai_prompt + +import android.content.Context +import com.google.common.util.concurrent.FutureCallback +import com.google.common.util.concurrent.Futures +import com.google.common.util.concurrent.ListenableFuture +import com.google.mlkit.genai.common.DownloadCallback +import com.google.mlkit.genai.common.FeatureStatus +import com.google.mlkit.genai.common.GenAiException +import io.flutter.plugin.common.MethodCall +import io.flutter.plugin.common.MethodChannel +import java.util.concurrent.Executor +import java.util.concurrent.Executors + +class Prompt( + private val context: Context, +) : MethodChannel.MethodCallHandler { + private val instances = mutableMapOf() + + companion object { + private const val CHECK_FEATURE_STATUS = "genai#checkFeatureStatus" + private const val DOWNLOAD_FEATURE = "genai#downloadFeature" + private const val RUN_INFERENCE = "genai#runInference" + private const val RUN_INFERENCE_STREAMING = "genai#runInferenceStreaming" + private const val CLOSE = "genai#closePrompt" + } + + private val executor: Executor = Executors.newSingleThreadExecutor() + + override fun onMethodCall( + call: MethodCall, + result: MethodChannel.Result, + ) { + when (call.method) { + CHECK_FEATURE_STATUS -> { + checkFeatureStatus(call, result) + } + + DOWNLOAD_FEATURE -> { + downloadFeature(call, result) + } + + RUN_INFERENCE -> { + runInference(call, result) + } + + RUN_INFERENCE_STREAMING -> { + result.notImplemented() + } + + CLOSE -> { + closePrompt(call) + result.success(null) + } + + else -> { + result.notImplemented() + } + } + } + + private fun initialize(call: MethodCall): Any = + try { + val generationInstance = + Class + .forName("com.google.mlkit.genai.prompt.Generation") + .getField("INSTANCE") + .get(null) + val generativeModel = + generationInstance + ?.javaClass + ?.getMethod("getClient") + ?.invoke(generationInstance) + Class + .forName("com.google.mlkit.genai.prompt.GenerativeModelFutures") + .getMethod("from", Class.forName("com.google.mlkit.genai.prompt.GenerativeModel")) + .invoke(null, generativeModel) ?: Any() + } catch (e: Exception) { + Any() + } + + private fun checkFeatureStatus( + call: MethodCall, + result: MethodChannel.Result, + ) { + val id = call.argument("id") ?: return result.error("PromptError", "Missing id", null) + val generativeModel = instances.getOrPut(id) { initialize(call) } + + try { + @Suppress("UNCHECKED_CAST") + val future = + generativeModel.javaClass + .getMethod("checkStatus") + .invoke(generativeModel) as ListenableFuture + + Futures.addCallback( + future, + object : FutureCallback { + override fun onSuccess(status: Int?) { + val statusValue = + when (status) { + FeatureStatus.UNAVAILABLE -> 0 + FeatureStatus.DOWNLOADABLE -> 1 + FeatureStatus.DOWNLOADING -> 2 + FeatureStatus.AVAILABLE -> 3 + else -> 0 + } + result.success(statusValue) + } + + override fun onFailure(e: Throwable) { + result.error("PromptError", e.toString(), null) + } + }, + executor, + ) + } catch (e: Exception) { + result.error("PromptError", "Failed to check status: $e", null) + } + } + + private fun downloadFeature( + call: MethodCall, + result: MethodChannel.Result, + ) { + val id = call.argument("id") ?: return result.error("PromptError", "Missing id", null) + val generativeModel = instances.getOrPut(id) { initialize(call) } + + try { + generativeModel.javaClass + .getMethod("download", DownloadCallback::class.java) + .invoke( + generativeModel, + object : DownloadCallback { + override fun onDownloadStarted(p0: Long) {} + + override fun onDownloadFailed(e: GenAiException) { + result.error("DownloadError", e.toString(), null) + } + + override fun onDownloadProgress(totalBytesDownloaded: Long) {} + + override fun onDownloadCompleted() { + result.success(null) + } + }, + ) + } catch (e: Exception) { + result.error("PromptError", "Failed to download $e", null) + } + } + + private fun runInference( + call: MethodCall, + result: MethodChannel.Result, + ) { + result.error("PromptError", "Prompt API inference not yet fully implemented", null) + } + + private fun closePrompt(call: MethodCall) { + val id = call.argument("id") ?: return + instances.remove(id) + } +} diff --git a/packages/google_mlkit_genai_proofreading/android/build.gradle b/packages/google_mlkit_genai_proofreading/android/build.gradle index 161a470d..57f85534 100644 --- a/packages/google_mlkit_genai_proofreading/android/build.gradle +++ b/packages/google_mlkit_genai_proofreading/android/build.gradle @@ -2,6 +2,7 @@ group = "com.google_mlkit_genai_proofreading" version = "1.0" buildscript { + ext.kotlin_version = "2.2.20" repositories { google() mavenCentral() @@ -9,6 +10,7 @@ buildscript { dependencies { classpath("com.android.tools.build:gradle:8.13.0") + classpath('org.jetbrains.kotlin:kotlin-gradle-plugin:2.2.20') } } @@ -20,6 +22,7 @@ rootProject.allprojects { } apply plugin: "com.android.library" +apply plugin: "kotlin-android" android { namespace = "com.google_mlkit_genai_proofreading" @@ -31,6 +34,14 @@ android { targetCompatibility = JavaVersion.VERSION_11 } + kotlinOptions { + jvmTarget = '11' + } + + sourceSets { + main.java.srcDirs += "src/main/kotlin" + } + defaultConfig { minSdk = 26 } diff --git a/packages/google_mlkit_genai_proofreading/android/gradle/wrapper/gradle-wrapper.properties b/packages/google_mlkit_genai_proofreading/android/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 00000000..128196a7 --- /dev/null +++ b/packages/google_mlkit_genai_proofreading/android/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-9.0-milestone-1-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/packages/google_mlkit_genai_proofreading/android/src/main/java/com/google_mlkit_genai_proofreading/GoogleMlKitGenaiProofreadingPlugin.java b/packages/google_mlkit_genai_proofreading/android/src/main/java/com/google_mlkit_genai_proofreading/GoogleMlKitGenaiProofreadingPlugin.java deleted file mode 100644 index d4047199..00000000 --- a/packages/google_mlkit_genai_proofreading/android/src/main/java/com/google_mlkit_genai_proofreading/GoogleMlKitGenaiProofreadingPlugin.java +++ /dev/null @@ -1,22 +0,0 @@ -package com.google_mlkit_genai_proofreading; - -import androidx.annotation.NonNull; - -import io.flutter.embedding.engine.plugins.FlutterPlugin; -import io.flutter.plugin.common.MethodChannel; - -public class GoogleMlKitGenaiProofreadingPlugin implements FlutterPlugin { - private MethodChannel channel; - private static final String channelName = "google_mlkit_genai_proofreading"; - - @Override - public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) { - channel = new MethodChannel(flutterPluginBinding.getBinaryMessenger(), channelName); - channel.setMethodCallHandler(new Proofreader(flutterPluginBinding.getApplicationContext())); - } - - @Override - public void onDetachedFromEngine(@NonNull FlutterPluginBinding binding) { - channel.setMethodCallHandler(null); - } -} diff --git a/packages/google_mlkit_genai_proofreading/android/src/main/java/com/google_mlkit_genai_proofreading/Proofreader.java b/packages/google_mlkit_genai_proofreading/android/src/main/java/com/google_mlkit_genai_proofreading/Proofreader.java deleted file mode 100644 index ecf36346..00000000 --- a/packages/google_mlkit_genai_proofreading/android/src/main/java/com/google_mlkit_genai_proofreading/Proofreader.java +++ /dev/null @@ -1,189 +0,0 @@ -package com.google_mlkit_genai_proofreading; - -import android.content.Context; - -import androidx.annotation.NonNull; - -import com.google.mlkit.genai.proofreading.Proofreading; -import com.google.mlkit.genai.proofreading.ProofreadingRequest; -import com.google.mlkit.genai.proofreading.ProofreaderOptions; - -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.FutureCallback; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.Executor; -import java.util.concurrent.Executors; - -import io.flutter.plugin.common.MethodCall; -import io.flutter.plugin.common.MethodChannel; - -public class Proofreader implements MethodChannel.MethodCallHandler { - private static final String CHECK_FEATURE_STATUS = "genai#checkFeatureStatus"; - private static final String DOWNLOAD_FEATURE = "genai#downloadFeature"; - private static final String RUN_INFERENCE = "genai#runInference"; - private static final String RUN_INFERENCE_STREAMING = "genai#runInferenceStreaming"; - private static final String CLOSE = "genai#closeProofreader"; - - private final Context context; - private final Map instances = new HashMap<>(); - private final Executor executor = Executors.newSingleThreadExecutor(); - - public Proofreader(Context context) { - this.context = context; - } - - @Override - public void onMethodCall(@NonNull MethodCall call, @NonNull MethodChannel.Result result) { - String method = call.method; - switch (method) { - case CHECK_FEATURE_STATUS: - checkFeatureStatus(call, result); - break; - case DOWNLOAD_FEATURE: - downloadFeature(call, result); - break; - case RUN_INFERENCE: - runInference(call, result); - break; - case RUN_INFERENCE_STREAMING: - runInferenceStreaming(call, result); - break; - case CLOSE: - closeProofreader(call); - result.success(null); - break; - default: - result.notImplemented(); - break; - } - } - - private com.google.mlkit.genai.proofreading.Proofreader initialize(MethodCall call) { - // Use basic ProofreaderOptions builder - API structure may vary - ProofreaderOptions options = ProofreaderOptions.builder(context).build(); - return Proofreading.getClient(options); - } - - private void checkFeatureStatus(MethodCall call, MethodChannel.Result result) { - String id = call.argument("id"); - com.google.mlkit.genai.proofreading.Proofreader proofreader = instances.get(id); - if (proofreader == null) { - proofreader = initialize(call); - instances.put(id, proofreader); - } - - ListenableFuture future = proofreader.checkFeatureStatus(); - Futures.addCallback(future, new FutureCallback() { - @Override - public void onSuccess(Integer status) { - int statusValue; - if (status == com.google.mlkit.genai.common.FeatureStatus.UNAVAILABLE) { - statusValue = 0; - } else if (status == com.google.mlkit.genai.common.FeatureStatus.DOWNLOADABLE) { - statusValue = 1; - } else if (status == com.google.mlkit.genai.common.FeatureStatus.DOWNLOADING) { - statusValue = 2; - } else if (status == com.google.mlkit.genai.common.FeatureStatus.AVAILABLE) { - statusValue = 3; - } else { - statusValue = 0; - } - result.success(statusValue); - } - - @Override - public void onFailure(Throwable e) { - result.error("ProofreaderError", e.toString(), null); - } - }, executor); - } - - private void downloadFeature(MethodCall call, MethodChannel.Result result) { - String id = call.argument("id"); - com.google.mlkit.genai.proofreading.Proofreader proofreader = instances.get(id); - if (proofreader == null) { - proofreader = initialize(call); - instances.put(id, proofreader); - } - - proofreader.downloadFeature(new com.google.mlkit.genai.common.DownloadCallback() { - @Override - public void onDownloadStarted(long bytesToDownload) { - // Handle download started - } - - @Override - public void onDownloadFailed(com.google.mlkit.genai.common.GenAiException e) { - result.error("DownloadError", e.toString(), null); - } - - @Override - public void onDownloadProgress(long totalBytesDownloaded) { - // Handle download progress - } - - @Override - public void onDownloadCompleted() { - result.success(null); - } - }); - } - - private void runInference(MethodCall call, MethodChannel.Result result) { - String id = call.argument("id"); - String text = call.argument("text"); - com.google.mlkit.genai.proofreading.Proofreader proofreader = instances.get(id); - if (proofreader == null) { - proofreader = initialize(call); - instances.put(id, proofreader); - } - - ProofreadingRequest request = ProofreadingRequest.builder(text).build(); - ListenableFuture future = proofreader.runInference(request); - Futures.addCallback(future, new FutureCallback() { - @Override - public void onSuccess(com.google.mlkit.genai.proofreading.ProofreadingResult proofreadingResult) { - Map response = new HashMap<>(); - // ProofreadingResult may have getCorrectedText() or similar method - // Use reflection to find the correct method - try { - String correctedText = (String) proofreadingResult.getClass().getMethod("getCorrectedText").invoke(proofreadingResult); - response.put("text", correctedText); - } catch (Exception e1) { - try { - // Try getText() if getCorrectedText() doesn't exist - String correctedText = (String) proofreadingResult.getClass().getMethod("getText").invoke(proofreadingResult); - response.put("text", correctedText); - } catch (Exception e2) { - // If both fail, return empty string - response.put("text", ""); - } - } - result.success(response); - } - - @Override - public void onFailure(Throwable e) { - result.error("InferenceError", e.toString(), null); - } - }, executor); - } - - private void runInferenceStreaming(MethodCall call, MethodChannel.Result result) { - // Streaming implementation would require EventChannel - result.notImplemented(); - } - - private void closeProofreader(MethodCall call) { - String id = call.argument("id"); - com.google.mlkit.genai.proofreading.Proofreader proofreader = instances.get(id); - if (proofreader == null) return; - proofreader.close(); - instances.remove(id); - } -} diff --git a/packages/google_mlkit_genai_proofreading/android/src/main/kotlin/com/google_mlkit_genai_proofreading/GoogleMlKitGenaiProofreadingPlugin.kt b/packages/google_mlkit_genai_proofreading/android/src/main/kotlin/com/google_mlkit_genai_proofreading/GoogleMlKitGenaiProofreadingPlugin.kt new file mode 100644 index 00000000..54c9ba94 --- /dev/null +++ b/packages/google_mlkit_genai_proofreading/android/src/main/kotlin/com/google_mlkit_genai_proofreading/GoogleMlKitGenaiProofreadingPlugin.kt @@ -0,0 +1,23 @@ +package com.google_mlkit_genai_proofreading + +import android.content.Context +import io.flutter.embedding.engine.plugins.FlutterPlugin +import io.flutter.plugin.common.MethodCall +import io.flutter.plugin.common.MethodChannel + +class GoogleMlKitGenaiProofreadingPlugin : FlutterPlugin { + private lateinit var channel: MethodChannel + + override fun onAttachedToEngine(flutterPluginBinding: FlutterPlugin.FlutterPluginBinding) { + channel = MethodChannel(flutterPluginBinding.binaryMessenger, CHANNEL_NAME) + channel.setMethodCallHandler(Proofreader(flutterPluginBinding.applicationContext)) + } + + override fun onDetachedFromEngine(p0: FlutterPlugin.FlutterPluginBinding) { + channel.setMethodCallHandler(null) + } + + companion object { + private const val CHANNEL_NAME = "google_mlkit_genai_proofreading" + } +} diff --git a/packages/google_mlkit_genai_proofreading/android/src/main/kotlin/com/google_mlkit_genai_proofreading/Proofreader.kt b/packages/google_mlkit_genai_proofreading/android/src/main/kotlin/com/google_mlkit_genai_proofreading/Proofreader.kt new file mode 100644 index 00000000..b3982661 --- /dev/null +++ b/packages/google_mlkit_genai_proofreading/android/src/main/kotlin/com/google_mlkit_genai_proofreading/Proofreader.kt @@ -0,0 +1,160 @@ +package com.google_mlkit_genai_proofreading + +import android.content.Context +import com.google.common.util.concurrent.FutureCallback +import com.google.common.util.concurrent.Futures +import com.google.mlkit.genai.common.DownloadCallback +import com.google.mlkit.genai.common.FeatureStatus +import com.google.mlkit.genai.common.GenAiException +import com.google.mlkit.genai.proofreading.ProofreaderOptions +import com.google.mlkit.genai.proofreading.Proofreading +import com.google.mlkit.genai.proofreading.ProofreadingRequest +import com.google.mlkit.genai.proofreading.ProofreadingResult +import io.flutter.plugin.common.MethodCall +import io.flutter.plugin.common.MethodChannel +import java.util.concurrent.Executors + +class Proofreader( + private val context: Context, +) : MethodChannel.MethodCallHandler { + companion object { + private const val CHECK_FEATURE_STATUS = "genai#checkFeatureStatus" + private const val DOWNLOAD_FEATURE = "genai#downloadFeature" + private const val RUN_INFERENCE = "genai#runInference" + private const val RUN_INFERENCE_STREAMING = "genai#runInferenceStreaming" + private const val CLOSE = "genai#closeProofreader" + } + + private val instances = mutableMapOf() + private val executor = Executors.newSingleThreadExecutor() + + override fun onMethodCall( + call: MethodCall, + result: MethodChannel.Result, + ) { + when (call.method) { + CHECK_FEATURE_STATUS -> { + checkFeatureStatus(call, result) + } + + DOWNLOAD_FEATURE -> { + downloadFeature(call, result) + } + + RUN_INFERENCE -> { + runInference(call, result) + } + + RUN_INFERENCE_STREAMING -> { + result.notImplemented() + } + + CLOSE -> { + closeProofreader(call) + result.success(null) + } + + else -> { + result.notImplemented() + } + } + } + + private fun initialize(call: MethodCall): com.google.mlkit.genai.proofreading.Proofreader { + val options = ProofreaderOptions.builder(context).build() + return Proofreading.getClient(options) + } + + private fun checkFeatureStatus( + call: MethodCall, + result: MethodChannel.Result, + ) { + val id = call.argument("id") ?: return result.error("ProofreaderError", "Missing id", null) + val proofreader = instances.getOrPut(id) { initialize(call) } + + val future = proofreader.checkFeatureStatus() + Futures.addCallback( + future, + object : FutureCallback { + override fun onSuccess(status: Int?) { + val statusValue = + when (status) { + FeatureStatus.UNAVAILABLE -> 0 + FeatureStatus.DOWNLOADABLE -> 1 + FeatureStatus.DOWNLOADING -> 2 + FeatureStatus.AVAILABLE -> 3 + else -> 0 + } + result.success(statusValue) + } + + override fun onFailure(e: Throwable) { + result.error("ProofreaderError", e.toString(), null) + } + }, + executor, + ) + } + + private fun downloadFeature( + call: MethodCall, + result: MethodChannel.Result, + ) { + val id = call.argument("id") ?: return result.error("ProofreaderError", "Missing id", null) + val proofreader = instances.getOrPut(id) { initialize(call) } + + proofreader.downloadFeature( + object : DownloadCallback { + override fun onDownloadStarted(bytesToDownload: Long) {} + + override fun onDownloadFailed(e: GenAiException) { + result.error("DownloadError", e.toString(), null) + } + + override fun onDownloadProgress(totalBytesDownloaded: Long) {} + + override fun onDownloadCompleted() { + result.success(null) + } + }, + ) + } + + private fun runInference( + call: MethodCall, + result: MethodChannel.Result, + ) { + val id = call.argument("id") ?: return result.error("ProofreaderError", "Missing id", null) + val text = call.argument("text") ?: return result.error("ProofreaderError", "Missing text", null) + val proofreader = instances.getOrPut(id) { initialize(call) } + + val request = ProofreadingRequest.builder(text).build() + val future = proofreader.runInference(request) + Futures.addCallback( + future, + object : FutureCallback { + override fun onSuccess(proofreadingResult: ProofreadingResult?) { + val correctedText = + proofreadingResult?.let { + runCatching { + it.javaClass.getMethod("getCorrectedText").invoke(it) as? String + }.getOrNull() ?: runCatching { + it.javaClass.getMethod("getText").invoke(it) as? String + }.getOrNull() ?: "" + } ?: "" + result.success(mapOf("text" to correctedText)) + } + + override fun onFailure(e: Throwable) { + result.error("InferenceError", e.toString(), null) + } + }, + executor, + ) + } + + private fun closeProofreader(call: MethodCall) { + val id = call.argument("id") ?: return + instances.remove(id)?.close() + } +} diff --git a/packages/google_mlkit_genai_rewriting/android/build.gradle b/packages/google_mlkit_genai_rewriting/android/build.gradle index e88fa452..96a4ae96 100644 --- a/packages/google_mlkit_genai_rewriting/android/build.gradle +++ b/packages/google_mlkit_genai_rewriting/android/build.gradle @@ -2,6 +2,7 @@ group = "com.google_mlkit_genai_rewriting" version = "1.0" buildscript { + ext.kotlin_version = "2.2.20" repositories { google() mavenCentral() @@ -9,6 +10,7 @@ buildscript { dependencies { classpath("com.android.tools.build:gradle:8.13.0") + classpath('org.jetbrains.kotlin:kotlin-gradle-plugin:2.2.20') } } @@ -20,6 +22,7 @@ rootProject.allprojects { } apply plugin: "com.android.library" +apply plugin: "kotlin-android" android { namespace = "com.google_mlkit_genai_rewriting" @@ -31,6 +34,14 @@ android { targetCompatibility = JavaVersion.VERSION_11 } + kotlinOptions { + jvmTarget = '11' + } + + sourceSets { + main.java.srcDirs += "src/main/kotlin" + } + defaultConfig { minSdk = 26 } diff --git a/packages/google_mlkit_genai_rewriting/android/gradle/wrapper/gradle-wrapper.properties b/packages/google_mlkit_genai_rewriting/android/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 00000000..128196a7 --- /dev/null +++ b/packages/google_mlkit_genai_rewriting/android/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-9.0-milestone-1-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/packages/google_mlkit_genai_rewriting/android/src/main/java/com/google_mlkit_genai_rewriting/GoogleMlKitGenaiRewritingPlugin.java b/packages/google_mlkit_genai_rewriting/android/src/main/java/com/google_mlkit_genai_rewriting/GoogleMlKitGenaiRewritingPlugin.java deleted file mode 100644 index 8a68af70..00000000 --- a/packages/google_mlkit_genai_rewriting/android/src/main/java/com/google_mlkit_genai_rewriting/GoogleMlKitGenaiRewritingPlugin.java +++ /dev/null @@ -1,22 +0,0 @@ -package com.google_mlkit_genai_rewriting; - -import androidx.annotation.NonNull; - -import io.flutter.embedding.engine.plugins.FlutterPlugin; -import io.flutter.plugin.common.MethodChannel; - -public class GoogleMlKitGenaiRewritingPlugin implements FlutterPlugin { - private MethodChannel channel; - private static final String channelName = "google_mlkit_genai_rewriting"; - - @Override - public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) { - channel = new MethodChannel(flutterPluginBinding.getBinaryMessenger(), channelName); - channel.setMethodCallHandler(new Rewriter(flutterPluginBinding.getApplicationContext())); - } - - @Override - public void onDetachedFromEngine(@NonNull FlutterPluginBinding binding) { - channel.setMethodCallHandler(null); - } -} diff --git a/packages/google_mlkit_genai_rewriting/android/src/main/java/com/google_mlkit_genai_rewriting/Rewriter.java b/packages/google_mlkit_genai_rewriting/android/src/main/java/com/google_mlkit_genai_rewriting/Rewriter.java deleted file mode 100644 index 52530950..00000000 --- a/packages/google_mlkit_genai_rewriting/android/src/main/java/com/google_mlkit_genai_rewriting/Rewriter.java +++ /dev/null @@ -1,189 +0,0 @@ -package com.google_mlkit_genai_rewriting; - -import android.content.Context; - -import androidx.annotation.NonNull; - -import com.google.mlkit.genai.rewriting.Rewriting; -import com.google.mlkit.genai.rewriting.RewritingRequest; -import com.google.mlkit.genai.rewriting.RewriterOptions; - -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.FutureCallback; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.Executor; -import java.util.concurrent.Executors; - -import io.flutter.plugin.common.MethodCall; -import io.flutter.plugin.common.MethodChannel; - -public class Rewriter implements MethodChannel.MethodCallHandler { - private static final String CHECK_FEATURE_STATUS = "genai#checkFeatureStatus"; - private static final String DOWNLOAD_FEATURE = "genai#downloadFeature"; - private static final String RUN_INFERENCE = "genai#runInference"; - private static final String RUN_INFERENCE_STREAMING = "genai#runInferenceStreaming"; - private static final String CLOSE = "genai#closeRewriter"; - - private final Context context; - private final Map instances = new HashMap<>(); - private final Executor executor = Executors.newSingleThreadExecutor(); - - public Rewriter(Context context) { - this.context = context; - } - - @Override - public void onMethodCall(@NonNull MethodCall call, @NonNull MethodChannel.Result result) { - String method = call.method; - switch (method) { - case CHECK_FEATURE_STATUS: - checkFeatureStatus(call, result); - break; - case DOWNLOAD_FEATURE: - downloadFeature(call, result); - break; - case RUN_INFERENCE: - runInference(call, result); - break; - case RUN_INFERENCE_STREAMING: - runInferenceStreaming(call, result); - break; - case CLOSE: - closeRewriter(call); - result.success(null); - break; - default: - result.notImplemented(); - break; - } - } - - private com.google.mlkit.genai.rewriting.Rewriter initialize(MethodCall call) { - // Use basic RewriterOptions builder - API structure may vary - RewriterOptions options = RewriterOptions.builder(context).build(); - return Rewriting.getClient(options); - } - - private void checkFeatureStatus(MethodCall call, MethodChannel.Result result) { - String id = call.argument("id"); - com.google.mlkit.genai.rewriting.Rewriter rewriter = instances.get(id); - if (rewriter == null) { - rewriter = initialize(call); - instances.put(id, rewriter); - } - - ListenableFuture future = rewriter.checkFeatureStatus(); - Futures.addCallback(future, new FutureCallback() { - @Override - public void onSuccess(Integer status) { - int statusValue; - if (status == com.google.mlkit.genai.common.FeatureStatus.UNAVAILABLE) { - statusValue = 0; - } else if (status == com.google.mlkit.genai.common.FeatureStatus.DOWNLOADABLE) { - statusValue = 1; - } else if (status == com.google.mlkit.genai.common.FeatureStatus.DOWNLOADING) { - statusValue = 2; - } else if (status == com.google.mlkit.genai.common.FeatureStatus.AVAILABLE) { - statusValue = 3; - } else { - statusValue = 0; - } - result.success(statusValue); - } - - @Override - public void onFailure(Throwable e) { - result.error("RewriterError", e.toString(), null); - } - }, executor); - } - - private void downloadFeature(MethodCall call, MethodChannel.Result result) { - String id = call.argument("id"); - com.google.mlkit.genai.rewriting.Rewriter rewriter = instances.get(id); - if (rewriter == null) { - rewriter = initialize(call); - instances.put(id, rewriter); - } - - rewriter.downloadFeature(new com.google.mlkit.genai.common.DownloadCallback() { - @Override - public void onDownloadStarted(long bytesToDownload) { - // Handle download started - } - - @Override - public void onDownloadFailed(com.google.mlkit.genai.common.GenAiException e) { - result.error("DownloadError", e.toString(), null); - } - - @Override - public void onDownloadProgress(long totalBytesDownloaded) { - // Handle download progress - } - - @Override - public void onDownloadCompleted() { - result.success(null); - } - }); - } - - private void runInference(MethodCall call, MethodChannel.Result result) { - String id = call.argument("id"); - String text = call.argument("text"); - com.google.mlkit.genai.rewriting.Rewriter rewriter = instances.get(id); - if (rewriter == null) { - rewriter = initialize(call); - instances.put(id, rewriter); - } - - RewritingRequest request = RewritingRequest.builder(text).build(); - ListenableFuture future = rewriter.runInference(request); - Futures.addCallback(future, new FutureCallback() { - @Override - public void onSuccess(com.google.mlkit.genai.rewriting.RewritingResult rewritingResult) { - Map response = new HashMap<>(); - // RewritingResult may have getText() or getRewrittenText() method - // Use reflection to find the correct method - try { - String rewrittenText = (String) rewritingResult.getClass().getMethod("getText").invoke(rewritingResult); - response.put("text", rewrittenText); - } catch (Exception e1) { - try { - // Try getRewrittenText() if getText() doesn't exist - String rewrittenText = (String) rewritingResult.getClass().getMethod("getRewrittenText").invoke(rewritingResult); - response.put("text", rewrittenText); - } catch (Exception e2) { - // If both fail, return empty string - response.put("text", ""); - } - } - result.success(response); - } - - @Override - public void onFailure(Throwable e) { - result.error("InferenceError", e.toString(), null); - } - }, executor); - } - - private void runInferenceStreaming(MethodCall call, MethodChannel.Result result) { - // Streaming implementation would require EventChannel - result.notImplemented(); - } - - private void closeRewriter(MethodCall call) { - String id = call.argument("id"); - com.google.mlkit.genai.rewriting.Rewriter rewriter = instances.get(id); - if (rewriter == null) return; - rewriter.close(); - instances.remove(id); - } -} diff --git a/packages/google_mlkit_genai_rewriting/android/src/main/kotlin/com/google_mlkit_genai_rewriting/GoogleMlKitGenaiRewritingPlugin.kt b/packages/google_mlkit_genai_rewriting/android/src/main/kotlin/com/google_mlkit_genai_rewriting/GoogleMlKitGenaiRewritingPlugin.kt new file mode 100644 index 00000000..019c4e0a --- /dev/null +++ b/packages/google_mlkit_genai_rewriting/android/src/main/kotlin/com/google_mlkit_genai_rewriting/GoogleMlKitGenaiRewritingPlugin.kt @@ -0,0 +1,21 @@ +package com.google_mlkit_genai_rewriting + +import io.flutter.embedding.engine.plugins.FlutterPlugin +import io.flutter.plugin.common.MethodChannel + +class GoogleMlKitGenaiRewritingPlugin : FlutterPlugin { + private lateinit var channel: MethodChannel + + companion object { + private const val CHANNEL_NAME = "google_mlkit_genai_rewriting" + } + + override fun onAttachedToEngine(flutterPluginBinding: FlutterPlugin.FlutterPluginBinding) { + channel = MethodChannel(flutterPluginBinding.binaryMessenger, CHANNEL_NAME) + channel.setMethodCallHandler(Rewriter(flutterPluginBinding.applicationContext)) + } + + override fun onDetachedFromEngine(binding: FlutterPlugin.FlutterPluginBinding) { + channel.setMethodCallHandler(null) + } +} diff --git a/packages/google_mlkit_genai_rewriting/android/src/main/kotlin/com/google_mlkit_genai_rewriting/Rewriter.kt b/packages/google_mlkit_genai_rewriting/android/src/main/kotlin/com/google_mlkit_genai_rewriting/Rewriter.kt new file mode 100644 index 00000000..d4dded89 --- /dev/null +++ b/packages/google_mlkit_genai_rewriting/android/src/main/kotlin/com/google_mlkit_genai_rewriting/Rewriter.kt @@ -0,0 +1,161 @@ +package com.google_mlkit_genai_rewriting + +import android.content.Context +import com.google.common.util.concurrent.FutureCallback +import com.google.common.util.concurrent.Futures +import com.google.mlkit.genai.common.DownloadCallback +import com.google.mlkit.genai.common.FeatureStatus +import com.google.mlkit.genai.common.GenAiException +import com.google.mlkit.genai.rewriting.RewriterOptions +import com.google.mlkit.genai.rewriting.Rewriting +import com.google.mlkit.genai.rewriting.RewritingRequest +import com.google.mlkit.genai.rewriting.RewritingResult +import io.flutter.plugin.common.MethodCall +import io.flutter.plugin.common.MethodChannel +import java.util.concurrent.Executors + +class Rewriter( + private val context: Context, +) : MethodChannel.MethodCallHandler { + private val instances = mutableMapOf() + private val executor = Executors.newSingleThreadExecutor() + + companion object { + private const val CHECK_FEATURE_STATUS = "genai#checkFeatureStatus" + private const val DOWNLOAD_FEATURE = "genai#downloadFeature" + private const val RUN_INFERENCE = "genai#runInference" + private const val RUN_INFERENCE_STREAMING = "genai#runInferenceStreaming" + private const val CLOSE = "genai#closeRewriter" + } + + override fun onMethodCall( + call: MethodCall, + result: MethodChannel.Result, + ) { + when (call.method) { + CHECK_FEATURE_STATUS -> { + checkFeatureStatus(call, result) + } + + DOWNLOAD_FEATURE -> { + downloadFeature(call, result) + } + + RUN_INFERENCE -> { + runInference(call, result) + } + + RUN_INFERENCE_STREAMING -> { + result.notImplemented() + } + + CLOSE -> { + closeRewriter(call) + result.success(null) + } + + else -> { + result.notImplemented() + } + } + } + + private fun initialize(call: MethodCall): com.google.mlkit.genai.rewriting.Rewriter { + val options = RewriterOptions.builder(context).build() + return Rewriting.getClient(options) + } + + private fun checkFeatureStatus( + call: MethodCall, + result: MethodChannel.Result, + ) { + val id = call.argument("id") ?: return + val rewriter = instances.getOrPut(id) { initialize(call) } + + val future = rewriter.checkFeatureStatus() + Futures.addCallback( + future, + object : FutureCallback { + override fun onSuccess(status: Int?) { + val statusValue = + when (status) { + FeatureStatus.UNAVAILABLE -> 0 + FeatureStatus.DOWNLOADABLE -> 1 + FeatureStatus.DOWNLOADING -> 2 + FeatureStatus.AVAILABLE -> 3 + else -> 0 + } + result.success(statusValue) + } + + override fun onFailure(e: Throwable) { + result.error("RewriterError", e.toString(), null) + } + }, + executor, + ) + } + + private fun downloadFeature( + call: MethodCall, + result: MethodChannel.Result, + ) { + val id = call.argument("id") ?: return + val rewriter = instances.getOrPut(id) { initialize(call) } + + rewriter.downloadFeature( + object : DownloadCallback { + override fun onDownloadStarted(bytesToDownload: Long) {} + + override fun onDownloadFailed(e: GenAiException) { + result.error("DownloadError", e.toString(), null) + } + + override fun onDownloadProgress(totalBytesDownloaded: Long) {} + + override fun onDownloadCompleted() { + result.success(null) + } + }, + ) + } + + private fun runInference( + call: MethodCall, + result: MethodChannel.Result, + ) { + val id = call.argument("id") ?: return + val text = call.argument("text") ?: return + val rewriter = instances.getOrPut(id) { initialize(call) } + + val request = RewritingRequest.builder(text).build() + val future = rewriter.runInference(request) + + Futures.addCallback( + future, + object : FutureCallback { + override fun onSuccess(rewritingResult: RewritingResult?) { + val rewrittenText = + rewritingResult?.let { + runCatching { it.javaClass.getMethod("getText").invoke(it) as? String } + .getOrNull() + ?: runCatching { it.javaClass.getMethod("getRewrittenText").invoke(it) as? String } + .getOrNull() + ?: "" + } ?: "" + result.success(mapOf("text" to rewrittenText)) + } + + override fun onFailure(e: Throwable) { + result.error("InferenceError", e.toString(), null) + } + }, + executor, + ) + } + + private fun closeRewriter(call: MethodCall) { + val id = call.argument("id") ?: return + instances.remove(id)?.close() + } +} diff --git a/packages/google_mlkit_genai_speech_recognition/android/build.gradle b/packages/google_mlkit_genai_speech_recognition/android/build.gradle index 4e38f3b3..1ed6c5f3 100644 --- a/packages/google_mlkit_genai_speech_recognition/android/build.gradle +++ b/packages/google_mlkit_genai_speech_recognition/android/build.gradle @@ -2,6 +2,7 @@ group = "com.google_mlkit_genai_speech_recognition" version = "1.0" buildscript { + ext.kotlin_version = "2.2.20" repositories { google() mavenCentral() @@ -9,6 +10,7 @@ buildscript { dependencies { classpath("com.android.tools.build:gradle:8.13.0") + classpath('org.jetbrains.kotlin:kotlin-gradle-plugin:2.2.20') } } @@ -20,6 +22,7 @@ rootProject.allprojects { } apply plugin: "com.android.library" +apply plugin: "kotlin-android" android { namespace = "com.google_mlkit_genai_speech_recognition" @@ -31,6 +34,14 @@ android { targetCompatibility = JavaVersion.VERSION_11 } + kotlinOptions { + jvmTarget = '11' + } + + sourceSets { + main.java.srcDirs += "src/main/kotlin" + } + defaultConfig { minSdk = 26 } diff --git a/packages/google_mlkit_genai_speech_recognition/android/gradle/wrapper/gradle-wrapper.properties b/packages/google_mlkit_genai_speech_recognition/android/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 00000000..128196a7 --- /dev/null +++ b/packages/google_mlkit_genai_speech_recognition/android/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-9.0-milestone-1-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/packages/google_mlkit_genai_speech_recognition/android/src/main/java/com/google_mlkit_genai_speech_recognition/GoogleMlKitGenaiSpeechRecognitionPlugin.java b/packages/google_mlkit_genai_speech_recognition/android/src/main/java/com/google_mlkit_genai_speech_recognition/GoogleMlKitGenaiSpeechRecognitionPlugin.java deleted file mode 100644 index 8680d490..00000000 --- a/packages/google_mlkit_genai_speech_recognition/android/src/main/java/com/google_mlkit_genai_speech_recognition/GoogleMlKitGenaiSpeechRecognitionPlugin.java +++ /dev/null @@ -1,22 +0,0 @@ -package com.google_mlkit_genai_speech_recognition; - -import androidx.annotation.NonNull; - -import io.flutter.embedding.engine.plugins.FlutterPlugin; -import io.flutter.plugin.common.MethodChannel; - -public class GoogleMlKitGenaiSpeechRecognitionPlugin implements FlutterPlugin { - private MethodChannel channel; - private static final String channelName = "google_mlkit_genai_speech_recognition"; - - @Override - public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) { - channel = new MethodChannel(flutterPluginBinding.getBinaryMessenger(), channelName); - channel.setMethodCallHandler(new SpeechRecognizer(flutterPluginBinding.getApplicationContext())); - } - - @Override - public void onDetachedFromEngine(@NonNull FlutterPluginBinding binding) { - channel.setMethodCallHandler(null); - } -} diff --git a/packages/google_mlkit_genai_speech_recognition/android/src/main/java/com/google_mlkit_genai_speech_recognition/SpeechRecognizer.java b/packages/google_mlkit_genai_speech_recognition/android/src/main/java/com/google_mlkit_genai_speech_recognition/SpeechRecognizer.java deleted file mode 100644 index 588d8251..00000000 --- a/packages/google_mlkit_genai_speech_recognition/android/src/main/java/com/google_mlkit_genai_speech_recognition/SpeechRecognizer.java +++ /dev/null @@ -1,133 +0,0 @@ -package com.google_mlkit_genai_speech_recognition; - -import android.content.Context; - -import androidx.annotation.NonNull; - - -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.FutureCallback; - -import java.util.HashMap; -import java.util.Map; -import java.util.concurrent.Executor; -import java.util.concurrent.Executors; - -import io.flutter.plugin.common.MethodCall; -import io.flutter.plugin.common.MethodChannel; - -public class SpeechRecognizer implements MethodChannel.MethodCallHandler { - private static final String CHECK_STATUS = "genai#checkStatus"; - private static final String START_RECOGNITION = "genai#startRecognition"; - private static final String STOP_RECOGNITION = "genai#stopRecognition"; - private static final String CLOSE = "genai#closeSpeechRecognizer"; - - private final Context context; - private final Map instances = new HashMap<>(); - private final Executor executor = Executors.newSingleThreadExecutor(); - - public SpeechRecognizer(Context context) { - this.context = context; - } - - @Override - public void onMethodCall(@NonNull MethodCall call, @NonNull MethodChannel.Result result) { - String method = call.method; - switch (method) { - case CHECK_STATUS: - checkStatus(call, result); - break; - case START_RECOGNITION: - startRecognition(call, result); - break; - case STOP_RECOGNITION: - stopRecognition(call, result); - break; - case CLOSE: - closeSpeechRecognizer(call); - result.success(null); - break; - default: - result.notImplemented(); - break; - } - } - - private Object initialize(MethodCall call) { - // Speech Recognition API initialization - structure may vary - // This is a placeholder implementation using reflection - try { - Object optionsBuilder = Class.forName("com.google.mlkit.genai.speechrecognition.SpeechRecognizerOptions") - .getMethod("builder", Context.class).invoke(null, context); - Object options = optionsBuilder.getClass().getMethod("build").invoke(optionsBuilder); - Object speechRecognizer = Class.forName("com.google.mlkit.genai.speechrecognition.SpeechRecognition") - .getMethod("getClient", options.getClass()).invoke(null, options); - return speechRecognizer; - } catch (Exception e) { - // If reflection fails, return a placeholder - return new Object(); - } - } - - private void checkStatus(MethodCall call, MethodChannel.Result result) { - String id = call.argument("id"); - Object speechRecognizer = instances.get(id); - if (speechRecognizer == null) { - speechRecognizer = initialize(call); - instances.put(id, speechRecognizer); - } - - try { - // Speech Recognition uses checkStatus() instead of checkFeatureStatus() - ListenableFuture future = (ListenableFuture) speechRecognizer.getClass() - .getMethod("checkStatus").invoke(speechRecognizer); - Futures.addCallback(future, new FutureCallback() { - @Override - public void onSuccess(Integer status) { - int statusValue; - if (status == com.google.mlkit.genai.common.FeatureStatus.UNAVAILABLE) { - statusValue = 0; - } else if (status == com.google.mlkit.genai.common.FeatureStatus.DOWNLOADABLE) { - statusValue = 1; - } else if (status == com.google.mlkit.genai.common.FeatureStatus.DOWNLOADING) { - statusValue = 2; - } else if (status == com.google.mlkit.genai.common.FeatureStatus.AVAILABLE) { - statusValue = 3; - } else { - statusValue = 0; - } - result.success(statusValue); - } - - @Override - public void onFailure(Throwable e) { - result.error("SpeechRecognizerError", e.toString(), null); - } - }, executor); - } catch (Exception e) { - result.error("SpeechRecognizerError", "Failed to check status: " + e.toString(), null); - } - } - - private void startRecognition(MethodCall call, MethodChannel.Result result) { - // Speech Recognition uses streaming - would need EventChannel - result.notImplemented(); - } - - private void stopRecognition(MethodCall call, MethodChannel.Result result) { - result.notImplemented(); - } - - private void closeSpeechRecognizer(MethodCall call) { - String id = call.argument("id"); - Object speechRecognizer = instances.get(id); - if (speechRecognizer == null) return; - try { - speechRecognizer.getClass().getMethod("close").invoke(speechRecognizer); - } catch (Exception e) { - // If close() doesn't exist, just remove from instances - } - instances.remove(id); - } -} diff --git a/packages/google_mlkit_genai_speech_recognition/android/src/main/kotlin/com/google_mlkit_genai_speech_recognition/GoogleMlKitGenaiSpeechRecognitionPlugin.kt b/packages/google_mlkit_genai_speech_recognition/android/src/main/kotlin/com/google_mlkit_genai_speech_recognition/GoogleMlKitGenaiSpeechRecognitionPlugin.kt new file mode 100644 index 00000000..0f2f0ab9 --- /dev/null +++ b/packages/google_mlkit_genai_speech_recognition/android/src/main/kotlin/com/google_mlkit_genai_speech_recognition/GoogleMlKitGenaiSpeechRecognitionPlugin.kt @@ -0,0 +1,21 @@ +package com.google_mlkit_genai_speech_recognition + +import io.flutter.embedding.engine.plugins.FlutterPlugin +import io.flutter.plugin.common.MethodChannel + +class GoogleMlKitGenaiSpeechRecognitionPlugin : FlutterPlugin { + private lateinit var channel: MethodChannel + + companion object { + private const val CHANNEL_NAME = "google_mlkit_genai_speech_recognition" + } + + override fun onAttachedToEngine(flutterPluginBinding: FlutterPlugin.FlutterPluginBinding) { + channel = MethodChannel(flutterPluginBinding.binaryMessenger, CHANNEL_NAME) + channel.setMethodCallHandler(SpeechRecognizer(flutterPluginBinding.applicationContext)) + } + + override fun onDetachedFromEngine(p0: FlutterPlugin.FlutterPluginBinding) { + channel.setMethodCallHandler(null) + } +} diff --git a/packages/google_mlkit_genai_speech_recognition/android/src/main/kotlin/com/google_mlkit_genai_speech_recognition/SpeechRecognizer.kt b/packages/google_mlkit_genai_speech_recognition/android/src/main/kotlin/com/google_mlkit_genai_speech_recognition/SpeechRecognizer.kt new file mode 100644 index 00000000..38db401f --- /dev/null +++ b/packages/google_mlkit_genai_speech_recognition/android/src/main/kotlin/com/google_mlkit_genai_speech_recognition/SpeechRecognizer.kt @@ -0,0 +1,112 @@ +package com.google_mlkit_genai_speech_recognition + +import android.content.Context +import com.google.common.util.concurrent.FutureCallback +import com.google.common.util.concurrent.Futures +import com.google.common.util.concurrent.ListenableFuture +import com.google.mlkit.genai.common.FeatureStatus +import io.flutter.plugin.common.MethodCall +import io.flutter.plugin.common.MethodChannel +import java.util.concurrent.Executors + +class SpeechRecognizer( + private val context: Context, +) : MethodChannel.MethodCallHandler { + private val instances = mutableMapOf() + private val executor = Executors.newSingleThreadExecutor() + + companion object { + private const val CHECK_STATUS = "genai#checkStatus" + private const val START_RECOGNITION = "genai#startRecognition" + private const val STOP_RECOGNITION = "genai#stopRecognition" + private const val CLOSE = "genai#closeSpeechRecognizer" + } + + override fun onMethodCall( + call: MethodCall, + result: MethodChannel.Result, + ) { + when (call.method) { + CHECK_STATUS -> { + checkStatus(call, result) + } + + START_RECOGNITION -> { + result.notImplemented() + } + + STOP_RECOGNITION -> { + result.notImplemented() + } + + CLOSE -> { + closeSpeechRecognizer(call) + result.success(null) + } + + else -> { + result.notImplemented() + } + } + } + + private fun initialize(call: MethodCall): Any = + runCatching { + val optionsBuilder = + Class + .forName("com.google.mlkit.genai.speechrecognition.SpeechRecognizerOptions") + .getMethod("builder", Context::class.java) + .invoke(null, context)!! + val options = optionsBuilder.javaClass.getMethod("build").invoke(optionsBuilder)!! + Class + .forName("com.google.mlkit.genai.speechrecognition.SpeechRecognition") + .getMethod("getClient", options.javaClass) + .invoke(null, options)!! + }.getOrDefault(Any()) + + private fun checkStatus( + call: MethodCall, + result: MethodChannel.Result, + ) { + val id = call.argument("id") ?: return + val speechRecognizer = instances.getOrPut(id) { initialize(call) } + + runCatching { + @Suppress("UNCHECKED_CAST") + val future = + speechRecognizer.javaClass + .getMethod("checkStatus") + .invoke(speechRecognizer) as ListenableFuture + + Futures.addCallback( + future, + object : FutureCallback { + override fun onSuccess(status: Int?) { + val statusValue = + when (status) { + FeatureStatus.UNAVAILABLE -> 0 + FeatureStatus.DOWNLOADABLE -> 1 + FeatureStatus.DOWNLOADING -> 2 + FeatureStatus.AVAILABLE -> 3 + else -> 0 + } + result.success(statusValue) + } + + override fun onFailure(e: Throwable) { + result.error("SpeechRecognizerError", e.toString(), null) + } + }, + executor, + ) + }.onFailure { e -> + result.error("SpeechRecognizerError", "Failed to check status: $e", null) + } + } + + private fun closeSpeechRecognizer(call: MethodCall) { + val id = call.argument("id") ?: return + val speechRecognizer = instances.remove(id) ?: return + runCatching { speechRecognizer.javaClass.getMethod("close").invoke(speechRecognizer) } + } +} diff --git a/packages/google_mlkit_genai_summarization/android/build.gradle b/packages/google_mlkit_genai_summarization/android/build.gradle index c26b9bec..ddaf8a29 100644 --- a/packages/google_mlkit_genai_summarization/android/build.gradle +++ b/packages/google_mlkit_genai_summarization/android/build.gradle @@ -2,6 +2,7 @@ group = "com.google_mlkit_genai_summarization" version = "1.0" buildscript { + ext.kotlin_version = "2.2.20" repositories { google() mavenCentral() @@ -9,6 +10,7 @@ buildscript { dependencies { classpath("com.android.tools.build:gradle:8.13.0") + classpath('org.jetbrains.kotlin:kotlin-gradle-plugin:2.2.20') } } @@ -20,6 +22,7 @@ rootProject.allprojects { } apply plugin: "com.android.library" +apply plugin: "kotlin-android" android { namespace = "com.google_mlkit_genai_summarization" @@ -31,6 +34,14 @@ android { targetCompatibility = JavaVersion.VERSION_11 } + kotlinOptions { + jvmTarget = '11' + } + + sourceSets { + main.java.srcDirs += "src/main/kotlin" + } + defaultConfig { minSdk = 26 } diff --git a/packages/google_mlkit_genai_summarization/android/gradle/wrapper/gradle-wrapper.properties b/packages/google_mlkit_genai_summarization/android/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 00000000..128196a7 --- /dev/null +++ b/packages/google_mlkit_genai_summarization/android/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-9.0-milestone-1-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/packages/google_mlkit_genai_summarization/android/src/main/java/com/google_mlkit_genai_summarization/GoogleMlKitGenaiSummarizationPlugin.java b/packages/google_mlkit_genai_summarization/android/src/main/java/com/google_mlkit_genai_summarization/GoogleMlKitGenaiSummarizationPlugin.java deleted file mode 100644 index c53ef564..00000000 --- a/packages/google_mlkit_genai_summarization/android/src/main/java/com/google_mlkit_genai_summarization/GoogleMlKitGenaiSummarizationPlugin.java +++ /dev/null @@ -1,22 +0,0 @@ -package com.google_mlkit_genai_summarization; - -import androidx.annotation.NonNull; - -import io.flutter.embedding.engine.plugins.FlutterPlugin; -import io.flutter.plugin.common.MethodChannel; - -public class GoogleMlKitGenaiSummarizationPlugin implements FlutterPlugin { - private MethodChannel channel; - private static final String channelName = "google_mlkit_genai_summarization"; - - @Override - public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) { - channel = new MethodChannel(flutterPluginBinding.getBinaryMessenger(), channelName); - channel.setMethodCallHandler(new Summarizer(flutterPluginBinding.getApplicationContext())); - } - - @Override - public void onDetachedFromEngine(@NonNull FlutterPluginBinding binding) { - channel.setMethodCallHandler(null); - } -} diff --git a/packages/google_mlkit_genai_summarization/android/src/main/java/com/google_mlkit_genai_summarization/Summarizer.java b/packages/google_mlkit_genai_summarization/android/src/main/java/com/google_mlkit_genai_summarization/Summarizer.java deleted file mode 100644 index e2e68e42..00000000 --- a/packages/google_mlkit_genai_summarization/android/src/main/java/com/google_mlkit_genai_summarization/Summarizer.java +++ /dev/null @@ -1,174 +0,0 @@ -package com.google_mlkit_genai_summarization; - -import android.content.Context; - -import androidx.annotation.NonNull; - -import com.google.mlkit.genai.summarization.Summarization; -import com.google.mlkit.genai.summarization.SummarizationRequest; -import com.google.mlkit.genai.summarization.SummarizerOptions; - -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.FutureCallback; - -import java.util.HashMap; -import java.util.Map; -import java.util.concurrent.Executor; -import java.util.concurrent.Executors; - -import io.flutter.plugin.common.MethodCall; -import io.flutter.plugin.common.MethodChannel; - -public class Summarizer implements MethodChannel.MethodCallHandler { - private static final String CHECK_FEATURE_STATUS = "genai#checkFeatureStatus"; - private static final String DOWNLOAD_FEATURE = "genai#downloadFeature"; - private static final String RUN_INFERENCE = "genai#runInference"; - private static final String RUN_INFERENCE_STREAMING = "genai#runInferenceStreaming"; - private static final String CLOSE = "genai#closeSummarizer"; - - private final Context context; - private final Map instances = new HashMap<>(); - private final Executor executor = Executors.newSingleThreadExecutor(); - - public Summarizer(Context context) { - this.context = context; - } - - @Override - public void onMethodCall(@NonNull MethodCall call, @NonNull MethodChannel.Result result) { - String method = call.method; - switch (method) { - case CHECK_FEATURE_STATUS: - checkFeatureStatus(call, result); - break; - case DOWNLOAD_FEATURE: - downloadFeature(call, result); - break; - case RUN_INFERENCE: - runInference(call, result); - break; - case RUN_INFERENCE_STREAMING: - runInferenceStreaming(call, result); - break; - case CLOSE: - closeSummarizer(call); - result.success(null); - break; - default: - result.notImplemented(); - break; - } - } - - private com.google.mlkit.genai.summarization.Summarizer initialize(MethodCall call) { - // Use basic SummarizerOptions builder - API structure may vary - SummarizerOptions options = SummarizerOptions.builder(context).build(); - return Summarization.getClient(options); - } - - private void checkFeatureStatus(MethodCall call, MethodChannel.Result result) { - String id = call.argument("id"); - com.google.mlkit.genai.summarization.Summarizer summarizer = instances.get(id); - if (summarizer == null) { - summarizer = initialize(call); - instances.put(id, summarizer); - } - - ListenableFuture future = summarizer.checkFeatureStatus(); - Futures.addCallback(future, new FutureCallback() { - @Override - public void onSuccess(Integer status) { - int statusValue; - if (status == com.google.mlkit.genai.common.FeatureStatus.UNAVAILABLE) { - statusValue = 0; - } else if (status == com.google.mlkit.genai.common.FeatureStatus.DOWNLOADABLE) { - statusValue = 1; - } else if (status == com.google.mlkit.genai.common.FeatureStatus.DOWNLOADING) { - statusValue = 2; - } else if (status == com.google.mlkit.genai.common.FeatureStatus.AVAILABLE) { - statusValue = 3; - } else { - statusValue = 0; - } - result.success(statusValue); - } - - @Override - public void onFailure(Throwable e) { - result.error("SummarizerError", e.toString(), null); - } - }, executor); - } - - private void downloadFeature(MethodCall call, MethodChannel.Result result) { - String id = call.argument("id"); - com.google.mlkit.genai.summarization.Summarizer summarizer = instances.get(id); - if (summarizer == null) { - summarizer = initialize(call); - instances.put(id, summarizer); - } - - summarizer.downloadFeature(new com.google.mlkit.genai.common.DownloadCallback() { - @Override - public void onDownloadStarted(long bytesToDownload) { - // Handle download started - } - - @Override - public void onDownloadFailed(com.google.mlkit.genai.common.GenAiException e) { - result.error("DownloadError", e.toString(), null); - } - - @Override - public void onDownloadProgress(long totalBytesDownloaded) { - // Handle download progress - } - - @Override - public void onDownloadCompleted() { - result.success(null); - } - }); - } - - private void runInference(MethodCall call, MethodChannel.Result result) { - String id = call.argument("id"); - String text = call.argument("text"); - com.google.mlkit.genai.summarization.Summarizer summarizer = instances.get(id); - if (summarizer == null) { - summarizer = initialize(call); - instances.put(id, summarizer); - } - - SummarizationRequest request = SummarizationRequest.builder(text).build(); - ListenableFuture future = summarizer.runInference(request); - Futures.addCallback(future, new FutureCallback() { - @Override - public void onSuccess(com.google.mlkit.genai.summarization.SummarizationResult summarizationResult) { - Map resultMap = new HashMap<>(); - resultMap.put("summary", summarizationResult.getSummary()); - result.success(resultMap); - } - - @Override - public void onFailure(Throwable e) { - result.error("InferenceError", e.toString(), null); - } - }, executor); - } - - private void runInferenceStreaming(MethodCall call, MethodChannel.Result result) { - // Streaming implementation would require EventChannel - // For now, this is a placeholder - result.notImplemented(); - } - - private void closeSummarizer(MethodCall call) { - String id = call.argument("id"); - com.google.mlkit.genai.summarization.Summarizer summarizer = instances.get(id); - if (summarizer == null) return; - summarizer.close(); - instances.remove(id); - } -} diff --git a/packages/google_mlkit_genai_summarization/android/src/main/kotlin/com/google_mlkit_genai_summarization/GoogleMlKitGenaiSummarizationPlugin.kt b/packages/google_mlkit_genai_summarization/android/src/main/kotlin/com/google_mlkit_genai_summarization/GoogleMlKitGenaiSummarizationPlugin.kt new file mode 100644 index 00000000..68505415 --- /dev/null +++ b/packages/google_mlkit_genai_summarization/android/src/main/kotlin/com/google_mlkit_genai_summarization/GoogleMlKitGenaiSummarizationPlugin.kt @@ -0,0 +1,21 @@ +package com.google_mlkit_genai_summarization + +import io.flutter.embedding.engine.plugins.FlutterPlugin +import io.flutter.plugin.common.MethodChannel + +class GoogleMlKitGenaiSummarizationPlugin : FlutterPlugin { + private lateinit var channel: MethodChannel + + companion object { + private const val CHANNEL_NAME = "google_mlkit_genai_summarization" + } + + override fun onAttachedToEngine(flutterPluginBinding: FlutterPlugin.FlutterPluginBinding) { + channel = MethodChannel(flutterPluginBinding.binaryMessenger, CHANNEL_NAME) + channel.setMethodCallHandler(Summarizer(flutterPluginBinding.applicationContext)) + } + + override fun onDetachedFromEngine(binding: FlutterPlugin.FlutterPluginBinding) { + channel.setMethodCallHandler(null) + } +} diff --git a/packages/google_mlkit_genai_summarization/android/src/main/kotlin/com/google_mlkit_genai_summarization/Summarizer.kt b/packages/google_mlkit_genai_summarization/android/src/main/kotlin/com/google_mlkit_genai_summarization/Summarizer.kt new file mode 100644 index 00000000..b157f3a4 --- /dev/null +++ b/packages/google_mlkit_genai_summarization/android/src/main/kotlin/com/google_mlkit_genai_summarization/Summarizer.kt @@ -0,0 +1,153 @@ +package com.google_mlkit_genai_summarization + +import android.content.Context +import com.google.common.util.concurrent.FutureCallback +import com.google.common.util.concurrent.Futures +import com.google.mlkit.genai.common.DownloadCallback +import com.google.mlkit.genai.common.FeatureStatus +import com.google.mlkit.genai.common.GenAiException +import com.google.mlkit.genai.summarization.Summarization +import com.google.mlkit.genai.summarization.SummarizationRequest +import com.google.mlkit.genai.summarization.SummarizationResult +import com.google.mlkit.genai.summarization.SummarizerOptions +import io.flutter.plugin.common.MethodCall +import io.flutter.plugin.common.MethodChannel +import java.util.concurrent.Executors + +class Summarizer( + private val context: Context, +) : MethodChannel.MethodCallHandler { + private val instances = mutableMapOf() + private val executor = Executors.newSingleThreadExecutor() + + companion object { + private const val CHECK_FEATURE_STATUS = "genai#checkFeatureStatus" + private const val DOWNLOAD_FEATURE = "genai#downloadFeature" + private const val RUN_INFERENCE = "genai#runInference" + private const val RUN_INFERENCE_STREAMING = "genai#runInferenceStreaming" + private const val CLOSE = "genai#closeSummarizer" + } + + override fun onMethodCall( + call: MethodCall, + result: MethodChannel.Result, + ) { + when (call.method) { + CHECK_FEATURE_STATUS -> { + checkFeatureStatus(call, result) + } + + DOWNLOAD_FEATURE -> { + downloadFeature(call, result) + } + + RUN_INFERENCE -> { + runInference(call, result) + } + + RUN_INFERENCE_STREAMING -> { + result.notImplemented() + } + + CLOSE -> { + closeSummarizer(call) + result.success(null) + } + + else -> { + result.notImplemented() + } + } + } + + private fun initialize(): com.google.mlkit.genai.summarization.Summarizer { + val options = SummarizerOptions.builder(context).build() + return Summarization.getClient(options) + } + + private fun checkFeatureStatus( + call: MethodCall, + result: MethodChannel.Result, + ) { + val id = call.argument("id") ?: return + val summarizer = instances.getOrPut(id) { initialize() } + + val future = summarizer.checkFeatureStatus() + Futures.addCallback( + future, + object : FutureCallback { + override fun onSuccess(status: Int?) { + val statusValue = + when (status) { + FeatureStatus.UNAVAILABLE -> 0 + FeatureStatus.DOWNLOADABLE -> 1 + FeatureStatus.DOWNLOADING -> 2 + FeatureStatus.AVAILABLE -> 3 + else -> 0 + } + result.success(statusValue) + } + + override fun onFailure(e: Throwable) { + result.error("SummarizerError", e.toString(), null) + } + }, + executor, + ) + } + + private fun downloadFeature( + call: MethodCall, + result: MethodChannel.Result, + ) { + val id = call.argument("id") ?: return + val summarizer = instances.getOrPut(id) { initialize() } + + summarizer.downloadFeature( + object : DownloadCallback { + override fun onDownloadStarted(bytesToDownload: Long) {} + + override fun onDownloadFailed(e: GenAiException) { + result.error("DownloadError", e.toString(), null) + } + + override fun onDownloadProgress(totalBytesDownloaded: Long) {} + + override fun onDownloadCompleted() { + result.success(null) + } + }, + ) + } + + private fun runInference( + call: MethodCall, + result: MethodChannel.Result, + ) { + val id = call.argument("id") ?: return + val text = call.argument("text") ?: return + val summarizer = instances.getOrPut(id) { initialize() } + + val request = SummarizationRequest.builder(text).build() + val future = summarizer.runInference(request) + + Futures.addCallback( + future, + object : FutureCallback { + override fun onSuccess(summarizationResult: SummarizationResult?) { + result.success(mapOf("summary" to summarizationResult?.summary)) + } + + override fun onFailure(e: Throwable) { + result.error("InferenceError", e.toString(), null) + } + }, + executor, + ) + } + + private fun closeSummarizer(call: MethodCall) { + val id = call.argument("id") ?: return + instances.remove(id)?.close() + } +} diff --git a/packages/google_mlkit_image_labeling/android/build.gradle b/packages/google_mlkit_image_labeling/android/build.gradle index 4fb868e8..6d2768b3 100644 --- a/packages/google_mlkit_image_labeling/android/build.gradle +++ b/packages/google_mlkit_image_labeling/android/build.gradle @@ -2,6 +2,7 @@ group = "com.google_mlkit_image_labeling" version = "1.0" buildscript { + ext.kotlin_version = "2.2.20" repositories { google() mavenCentral() @@ -9,6 +10,7 @@ buildscript { dependencies { classpath("com.android.tools.build:gradle:8.13.0") + classpath('org.jetbrains.kotlin:kotlin-gradle-plugin:2.2.20') } } @@ -20,6 +22,7 @@ rootProject.allprojects { } apply plugin: "com.android.library" +apply plugin: "kotlin-android" android { namespace = "com.google_mlkit_image_labeling" @@ -31,6 +34,14 @@ android { targetCompatibility = JavaVersion.VERSION_11 } + kotlinOptions { + jvmTarget = '11' + } + + sourceSets { + main.java.srcDirs += "src/main/kotlin" + } + defaultConfig { minSdk = 21 } diff --git a/packages/google_mlkit_image_labeling/android/gradle/wrapper/gradle-wrapper.properties b/packages/google_mlkit_image_labeling/android/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 00000000..128196a7 --- /dev/null +++ b/packages/google_mlkit_image_labeling/android/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-9.0-milestone-1-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/packages/google_mlkit_image_labeling/android/src/main/java/com/google_mlkit_image_labeling/GoogleMlKitImageLabelingPlugin.java b/packages/google_mlkit_image_labeling/android/src/main/java/com/google_mlkit_image_labeling/GoogleMlKitImageLabelingPlugin.java deleted file mode 100644 index 47a83840..00000000 --- a/packages/google_mlkit_image_labeling/android/src/main/java/com/google_mlkit_image_labeling/GoogleMlKitImageLabelingPlugin.java +++ /dev/null @@ -1,22 +0,0 @@ -package com.google_mlkit_image_labeling; - -import androidx.annotation.NonNull; - -import io.flutter.embedding.engine.plugins.FlutterPlugin; -import io.flutter.plugin.common.MethodChannel; - -public class GoogleMlKitImageLabelingPlugin implements FlutterPlugin { - private MethodChannel channel; - private static final String channelName = "google_mlkit_image_labeler"; - - @Override - public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) { - channel = new MethodChannel(flutterPluginBinding.getBinaryMessenger(), channelName); - channel.setMethodCallHandler(new ImageLabelDetector(flutterPluginBinding.getApplicationContext())); - } - - @Override - public void onDetachedFromEngine(@NonNull FlutterPluginBinding binding) { - channel.setMethodCallHandler(null); - } -} diff --git a/packages/google_mlkit_image_labeling/android/src/main/java/com/google_mlkit_image_labeling/ImageLabelDetector.java b/packages/google_mlkit_image_labeling/android/src/main/java/com/google_mlkit_image_labeling/ImageLabelDetector.java deleted file mode 100644 index 93bc20e5..00000000 --- a/packages/google_mlkit_image_labeling/android/src/main/java/com/google_mlkit_image_labeling/ImageLabelDetector.java +++ /dev/null @@ -1,184 +0,0 @@ -package com.google_mlkit_image_labeling; - -import android.content.Context; - -import androidx.annotation.NonNull; - -import com.google.mlkit.common.model.CustomRemoteModel; -import com.google.mlkit.common.model.LocalModel; -import com.google.mlkit.linkfirebase.FirebaseModelSource; -import com.google.mlkit.vision.common.InputImage; -import com.google.mlkit.vision.label.ImageLabel; -import com.google.mlkit.vision.label.ImageLabeler; -import com.google.mlkit.vision.label.ImageLabeling; -import com.google.mlkit.vision.label.custom.CustomImageLabelerOptions; -import com.google.mlkit.vision.label.defaults.ImageLabelerOptions; -import com.google_mlkit_commons.GenericModelManager; -import com.google_mlkit_commons.InputImageConverter; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import io.flutter.plugin.common.MethodCall; -import io.flutter.plugin.common.MethodChannel; - -public class ImageLabelDetector implements MethodChannel.MethodCallHandler { - private static final String START = "vision#startImageLabelDetector"; - private static final String CLOSE = "vision#closeImageLabelDetector"; - private static final String MANAGE = "vision#manageFirebaseModels"; - - private final Context context; - private final Map instances = new HashMap<>(); - private final GenericModelManager genericModelManager = new GenericModelManager(); - - public ImageLabelDetector(Context context) { - this.context = context; - } - - @Override - public void onMethodCall(@NonNull MethodCall call, @NonNull MethodChannel.Result result) { - String method = call.method; - switch (method) { - case START: - handleDetection(call, result); - break; - case CLOSE: - closeDetector(call); - result.success(null); - break; - case MANAGE: - manageModel(call, result); - break; - default: - result.notImplemented(); - break; - } - } - - private void handleDetection(MethodCall call, final MethodChannel.Result result) { - Map imageData = call.argument("imageData"); - InputImage inputImage = InputImageConverter.getInputImageFromData(imageData, context, result); - if (inputImage == null) return; - - String id = call.argument("id"); - ImageLabeler imageLabeler = instances.get(id); - if (imageLabeler == null) { - Map options = call.argument("options"); - if (options == null) { - result.error("ImageLabelDetectorError", "Invalid options", null); - return; - } - - String type = (String) options.get("type"); - if (type.equals("base")) { - ImageLabelerOptions labelerOptions = getDefaultOptions(options); - imageLabeler = ImageLabeling.getClient(labelerOptions); - } else if (type.equals("local")) { - CustomImageLabelerOptions labelerOptions = getLocalOptions(options); - imageLabeler = ImageLabeling.getClient(labelerOptions); - } else if (type.equals("remote")) { - float confidenceThreshold = (float) (double) options.get("confidenceThreshold"); - int maxCount = (int) options.get("maxCount"); - String name = (String) options.get("modelName"); - - FirebaseModelSource firebaseModelSource = new FirebaseModelSource.Builder(name).build(); - CustomRemoteModel remoteModel = new CustomRemoteModel.Builder(firebaseModelSource).build(); - - genericModelManager.isModelDownloaded( - remoteModel, - new GenericModelManager.CheckModelIsDownloadedCallback() { - @Override - public void onCheckResult(Boolean isDownloaded) { - if (!isDownloaded) { - result.error("Error Model has not been downloaded yet", "Model has not been downloaded yet", "Model has not been downloaded yet"); - return; - } - - startImageLabelDetector( - ImageLabeling.getClient( - new CustomImageLabelerOptions.Builder(remoteModel) - .setConfidenceThreshold(confidenceThreshold) - .setMaxResultCount(maxCount) - .build() - ), - inputImage, - result - ); - } - - @Override - public void onError(Exception e) { - result.error("Model download check failed", e.getMessage(), e); - } - } - ); - - return; - } else { - String error = "Invalid model type: " + type; - result.error(type, error, error); - return; - } - instances.put(id, imageLabeler); - } - - startImageLabelDetector(imageLabeler, inputImage, result); - } - - private void startImageLabelDetector(ImageLabeler imageLabeler, InputImage inputImage, MethodChannel.Result result) { - imageLabeler.process(inputImage) - .addOnSuccessListener(imageLabels -> { - List> labels = new ArrayList<>(imageLabels.size()); - for (ImageLabel label : imageLabels) { - Map labelData = new HashMap<>(); - labelData.put("text", label.getText()); - labelData.put("confidence", label.getConfidence()); - labelData.put("index", label.getIndex()); - labels.add(labelData); - } - - result.success(labels); - }) - .addOnFailureListener(e -> result.error("ImageLabelDetectorError", e.toString(), null)); - } - - //Labeler options that are provided to default image labeler(uses inbuilt model). - private ImageLabelerOptions getDefaultOptions(Map labelerOptions) { - float confidenceThreshold = (float) (double) labelerOptions.get("confidenceThreshold"); - return new ImageLabelerOptions.Builder() - .setConfidenceThreshold(confidenceThreshold) - .build(); - } - - //Options for labeler to work with custom model. - private CustomImageLabelerOptions getLocalOptions(Map labelerOptions) { - float confidenceThreshold = (float) (double) labelerOptions.get("confidenceThreshold"); - int maxCount = (int) labelerOptions.get("maxCount"); - String path = (String) labelerOptions.get("path"); - LocalModel localModel = new LocalModel.Builder() - .setAbsoluteFilePath(path) - .build(); - return new CustomImageLabelerOptions.Builder(localModel) - .setConfidenceThreshold(confidenceThreshold) - .setMaxResultCount(maxCount) - .build(); - } - - private void closeDetector(MethodCall call) { - String id = call.argument("id"); - ImageLabeler imageLabeler = instances.get(id); - if (imageLabeler == null) return; - imageLabeler.close(); - instances.remove(id); - } - - private void manageModel(MethodCall call, final MethodChannel.Result result) { - FirebaseModelSource firebaseModelSource = new FirebaseModelSource.Builder(call.argument("model")) - .build(); - CustomRemoteModel model = new CustomRemoteModel.Builder(firebaseModelSource) - .build(); - genericModelManager.manageModel(model, call, result); - } -} diff --git a/packages/google_mlkit_image_labeling/android/src/main/kotlin/com/google_mlkit_image_labeling/GoogleMlKitImageLabelingPlugin.kt b/packages/google_mlkit_image_labeling/android/src/main/kotlin/com/google_mlkit_image_labeling/GoogleMlKitImageLabelingPlugin.kt new file mode 100644 index 00000000..c94c5f6e --- /dev/null +++ b/packages/google_mlkit_image_labeling/android/src/main/kotlin/com/google_mlkit_image_labeling/GoogleMlKitImageLabelingPlugin.kt @@ -0,0 +1,21 @@ +package com.google_mlkit_image_labeling + +import io.flutter.embedding.engine.plugins.FlutterPlugin +import io.flutter.plugin.common.MethodChannel + +class GoogleMlKitImageLabelingPlugin : FlutterPlugin { + private lateinit var channel: MethodChannel + + companion object { + private const val CHANNEL_NAME = "google_mlkit_image_labeler" + } + + override fun onAttachedToEngine(flutterPluginBinding: FlutterPlugin.FlutterPluginBinding) { + channel = MethodChannel(flutterPluginBinding.binaryMessenger, CHANNEL_NAME) + channel.setMethodCallHandler(ImageLabelDetector(flutterPluginBinding.applicationContext)) + } + + override fun onDetachedFromEngine(binding: FlutterPlugin.FlutterPluginBinding) { + channel.setMethodCallHandler(null) + } +} diff --git a/packages/google_mlkit_image_labeling/android/src/main/kotlin/com/google_mlkit_image_labeling/ImageLabelDetector.kt b/packages/google_mlkit_image_labeling/android/src/main/kotlin/com/google_mlkit_image_labeling/ImageLabelDetector.kt new file mode 100644 index 00000000..8dfb6fbf --- /dev/null +++ b/packages/google_mlkit_image_labeling/android/src/main/kotlin/com/google_mlkit_image_labeling/ImageLabelDetector.kt @@ -0,0 +1,195 @@ +package com.google_mlkit_image_labeling + +import android.content.Context +import com.google.mlkit.common.model.CustomRemoteModel +import com.google.mlkit.common.model.LocalModel +import com.google.mlkit.linkfirebase.FirebaseModelSource +import com.google.mlkit.vision.label.ImageLabeler +import com.google.mlkit.vision.label.ImageLabeling +import com.google.mlkit.vision.label.custom.CustomImageLabelerOptions +import com.google.mlkit.vision.label.defaults.ImageLabelerOptions +import com.google_mlkit_commons.GenericModelManager +import com.google_mlkit_commons.InputImageConverter +import io.flutter.plugin.common.MethodCall +import io.flutter.plugin.common.MethodChannel + +class ImageLabelDetector( + private val context: Context, +) : MethodChannel.MethodCallHandler { + private val instances = mutableMapOf() + private val genericModelManager = GenericModelManager() + + companion object { + private const val START = "vision#startImageLabelDetector" + private const val CLOSE = "vision#closeImageLabelDetector" + private const val MANAGE = "vision#manageFirebaseModels" + } + + override fun onMethodCall( + call: MethodCall, + result: MethodChannel.Result, + ) { + when (call.method) { + START -> { + handleDetection(call, result) + } + + CLOSE -> { + closeDetector(call) + result.success(null) + } + + MANAGE -> { + manageModel(call, result) + } + + else -> { + result.notImplemented() + } + } + } + + private fun handleDetection( + call: MethodCall, + result: MethodChannel.Result, + ) { + val imageData = + call.argument>("imageData") ?: run { + result.error("ImageLabelDetectorError", "imageData is null", null) + return + } + val inputImage = InputImageConverter.getInputImageFromData(imageData, context, result) ?: return + + val id = call.argument("id") ?: return + var imageLabeler = instances[id] + + if (imageLabeler == null) { + val options = + call.argument>("options") ?: run { + result.error("ImageLabelDetectorError", "Invalid options", null) + return + } + + when (val type = options["type"] as? String) { + "base" -> { + imageLabeler = ImageLabeling.getClient(getDefaultOptions(options)) + } + + "local" -> { + imageLabeler = ImageLabeling.getClient(getLocalOptions(options)) + } + + "remote" -> { + val confidenceThreshold = (options["confidenceThreshold"] as Double).toFloat() + val maxCount = options["maxCount"] as Int + val name = options["modelName"] as String + + val firebaseModelSource = FirebaseModelSource.Builder(name).build() + val remoteModel = CustomRemoteModel.Builder(firebaseModelSource).build() + + genericModelManager.isModelDownloaded( + remoteModel, + object : GenericModelManager.CheckModelIsDownloadedCallback { + override fun onCheckResult(isDownloaded: Boolean?) { + if (isDownloaded != true) { + result.error( + "Error Model has not been downloaded yet", + "Model has not been downloaded yet", + "Model has not been downloaded yet", + ) + return + } + startImageLabelDetector( + ImageLabeling.getClient( + CustomImageLabelerOptions + .Builder(remoteModel) + .setConfidenceThreshold(confidenceThreshold) + .setMaxResultCount(maxCount) + .build(), + ), + inputImage, + result, + ) + } + + override fun onError(e: Exception) { + result.error("Model download check failed", e.message, e) + } + }, + ) + return + } + + else -> { + val error = "Invalid model type: $type" + result.error(type ?: "unknown", error, error) + return + } + } + instances[id] = imageLabeler!! + } + + startImageLabelDetector(imageLabeler, inputImage, result) + } + + private fun startImageLabelDetector( + imageLabeler: ImageLabeler, + inputImage: com.google.mlkit.vision.common.InputImage, + result: MethodChannel.Result, + ) { + imageLabeler + .process(inputImage) + .addOnSuccessListener { imageLabels -> + val labels = + imageLabels.map { label -> + mapOf( + "text" to label.text, + "confidence" to label.confidence, + "index" to label.index, + ) + } + result.success(labels) + }.addOnFailureListener { e -> + result.error("ImageLabelDetectorError", e.toString(), null) + } + } + + private fun getDefaultOptions(labelerOptions: Map): ImageLabelerOptions { + val confidenceThreshold = (labelerOptions["confidenceThreshold"] as Double).toFloat() + return ImageLabelerOptions + .Builder() + .setConfidenceThreshold(confidenceThreshold) + .build() + } + + private fun getLocalOptions(labelerOptions: Map): CustomImageLabelerOptions { + val confidenceThreshold = (labelerOptions["confidenceThreshold"] as Double).toFloat() + val maxCount = labelerOptions["maxCount"] as Int + val path = labelerOptions["path"] as String + val localModel = LocalModel.Builder().setAbsoluteFilePath(path).build() + return CustomImageLabelerOptions + .Builder(localModel) + .setConfidenceThreshold(confidenceThreshold) + .setMaxResultCount(maxCount) + .build() + } + + private fun closeDetector(call: MethodCall) { + val id = call.argument("id") ?: return + instances.remove(id)?.close() + } + + private fun manageModel( + call: MethodCall, + result: MethodChannel.Result, + ) { + val modelName = + call.argument("model") ?: run { + result.error("ImageLabelDetectorError", "Model name is null", null) + return + } + val firebaseModelSource = FirebaseModelSource.Builder(modelName).build() + val model = CustomRemoteModel.Builder(firebaseModelSource).build() + genericModelManager.manageModel(model, call, result) + } +} diff --git a/packages/google_mlkit_language_id/android/build.gradle b/packages/google_mlkit_language_id/android/build.gradle index 48246bcb..c67154e4 100644 --- a/packages/google_mlkit_language_id/android/build.gradle +++ b/packages/google_mlkit_language_id/android/build.gradle @@ -2,6 +2,7 @@ group = "com.google_mlkit_language_id" version = "1.0" buildscript { + ext.kotlin_version = "2.2.20" repositories { google() mavenCentral() @@ -9,6 +10,7 @@ buildscript { dependencies { classpath("com.android.tools.build:gradle:8.13.0") + classpath('org.jetbrains.kotlin:kotlin-gradle-plugin:2.2.20') } } @@ -20,6 +22,7 @@ rootProject.allprojects { } apply plugin: "com.android.library" +apply plugin: "kotlin-android" android { namespace = "com.google_mlkit_language_id" @@ -30,6 +33,13 @@ android { sourceCompatibility = JavaVersion.VERSION_11 targetCompatibility = JavaVersion.VERSION_11 } + kotlinOptions { + jvmTarget = '11' + } + + sourceSets { + main.java.srcDirs += "src/main/kotlin" + } defaultConfig { minSdk = 21 diff --git a/packages/google_mlkit_language_id/android/gradle/wrapper/gradle-wrapper.properties b/packages/google_mlkit_language_id/android/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 00000000..128196a7 --- /dev/null +++ b/packages/google_mlkit_language_id/android/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-9.0-milestone-1-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/packages/google_mlkit_language_id/android/src/main/java/com/google_mlkit_language_id/GoogleMlKitLanguageIdPlugin.java b/packages/google_mlkit_language_id/android/src/main/java/com/google_mlkit_language_id/GoogleMlKitLanguageIdPlugin.java deleted file mode 100644 index b9f956db..00000000 --- a/packages/google_mlkit_language_id/android/src/main/java/com/google_mlkit_language_id/GoogleMlKitLanguageIdPlugin.java +++ /dev/null @@ -1,22 +0,0 @@ -package com.google_mlkit_language_id; - -import androidx.annotation.NonNull; - -import io.flutter.embedding.engine.plugins.FlutterPlugin; -import io.flutter.plugin.common.MethodChannel; - -public class GoogleMlKitLanguageIdPlugin implements FlutterPlugin { - private MethodChannel channel; - private static final String channelName = "google_mlkit_language_identifier"; - - @Override - public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) { - channel = new MethodChannel(flutterPluginBinding.getBinaryMessenger(), channelName); - channel.setMethodCallHandler(new LanguageDetector()); - } - - @Override - public void onDetachedFromEngine(@NonNull FlutterPluginBinding binding) { - channel.setMethodCallHandler(null); - } -} diff --git a/packages/google_mlkit_language_id/android/src/main/java/com/google_mlkit_language_id/LanguageDetector.java b/packages/google_mlkit_language_id/android/src/main/java/com/google_mlkit_language_id/LanguageDetector.java deleted file mode 100644 index 17deb735..00000000 --- a/packages/google_mlkit_language_id/android/src/main/java/com/google_mlkit_language_id/LanguageDetector.java +++ /dev/null @@ -1,90 +0,0 @@ -package com.google_mlkit_language_id; - -import androidx.annotation.NonNull; - -import com.google.mlkit.nl.languageid.IdentifiedLanguage; -import com.google.mlkit.nl.languageid.LanguageIdentification; -import com.google.mlkit.nl.languageid.LanguageIdentificationOptions; -import com.google.mlkit.nl.languageid.LanguageIdentifier; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import io.flutter.plugin.common.MethodCall; -import io.flutter.plugin.common.MethodChannel; - -public class LanguageDetector implements MethodChannel.MethodCallHandler { - private static final String START = "nlp#startLanguageIdentifier"; - private static final String CLOSE = "nlp#closeLanguageIdentifier"; - - private final Map instances = new HashMap<>(); - - @Override - public void onMethodCall(@NonNull MethodCall call, @NonNull MethodChannel.Result result) { - String method = call.method; - switch (method) { - case START: - identifyLanguages(call, result); - break; - case CLOSE: - closeDetector(call); - result.success(null); - break; - default: - result.notImplemented(); - break; - } - } - - private void identifyLanguages(MethodCall call, final MethodChannel.Result result) { - String id = call.argument("id"); - LanguageIdentifier languageIdentifier = instances.get(id); - if (languageIdentifier == null) { - double confidence = (double) call.argument("confidence"); - languageIdentifier = LanguageIdentification.getClient( - new LanguageIdentificationOptions.Builder() - .setConfidenceThreshold((float) confidence) - .build()); - instances.put(id, languageIdentifier); - } - - boolean possibleLanguages = (boolean) call.argument("possibleLanguages"); - String text = (String) call.argument("text"); - if (!possibleLanguages) { - identifyLanguage(text, languageIdentifier, result); - } else { - identifyPossibleLanguages(text, languageIdentifier, result); - } - } - - private void identifyLanguage(String text, LanguageIdentifier languageIdentifier, final MethodChannel.Result result) { - languageIdentifier.identifyLanguage(text) - .addOnSuccessListener(result::success) - .addOnFailureListener(e -> result.error("Language Identification Error", e.toString(), null)); - } - - private void identifyPossibleLanguages(String text, LanguageIdentifier languageIdentifier, final MethodChannel.Result result) { - languageIdentifier.identifyPossibleLanguages(text) - .addOnSuccessListener(identifiedLanguages -> { - List> languageList = new ArrayList<>(); - for (IdentifiedLanguage language : identifiedLanguages) { - Map languageData = new HashMap<>(); - languageData.put("confidence", language.getConfidence()); - languageData.put("language", language.getLanguageTag()); - languageList.add(languageData); - } - result.success(languageList); - }) - .addOnFailureListener(e -> result.error("Error identifying possible languages", e.toString(), null)); - } - - private void closeDetector(MethodCall call) { - String id = call.argument("id"); - LanguageIdentifier languageIdentifier = instances.get(id); - if (languageIdentifier == null) return; - languageIdentifier.close(); - instances.remove(id); - } -} diff --git a/packages/google_mlkit_language_id/android/src/main/kotlin/com/google_mlkit_language_id/GoogleMlKitLanguageIdPlugin.kt b/packages/google_mlkit_language_id/android/src/main/kotlin/com/google_mlkit_language_id/GoogleMlKitLanguageIdPlugin.kt new file mode 100644 index 00000000..4acb8c73 --- /dev/null +++ b/packages/google_mlkit_language_id/android/src/main/kotlin/com/google_mlkit_language_id/GoogleMlKitLanguageIdPlugin.kt @@ -0,0 +1,21 @@ +package com.google_mlkit_language_id + +import io.flutter.embedding.engine.plugins.FlutterPlugin +import io.flutter.plugin.common.MethodChannel + +class GoogleMlKitLanguageIdPlugin : FlutterPlugin { + private lateinit var channel: MethodChannel + + companion object { + private const val CHANNEL_NAME = "google_mlkit_language_identifier" + } + + override fun onAttachedToEngine(flutterPluginBinding: FlutterPlugin.FlutterPluginBinding) { + channel = MethodChannel(flutterPluginBinding.binaryMessenger, CHANNEL_NAME) + channel.setMethodCallHandler(LanguageDetector()) + } + + override fun onDetachedFromEngine(binding: FlutterPlugin.FlutterPluginBinding) { + channel.setMethodCallHandler(null) + } +} diff --git a/packages/google_mlkit_language_id/android/src/main/kotlin/com/google_mlkit_language_id/LanguageDetector.kt b/packages/google_mlkit_language_id/android/src/main/kotlin/com/google_mlkit_language_id/LanguageDetector.kt new file mode 100644 index 00000000..713dd644 --- /dev/null +++ b/packages/google_mlkit_language_id/android/src/main/kotlin/com/google_mlkit_language_id/LanguageDetector.kt @@ -0,0 +1,97 @@ +package com.google_mlkit_language_id + +import com.google.mlkit.nl.languageid.LanguageIdentification +import com.google.mlkit.nl.languageid.LanguageIdentificationOptions +import com.google.mlkit.nl.languageid.LanguageIdentifier +import io.flutter.plugin.common.MethodCall +import io.flutter.plugin.common.MethodChannel + +class LanguageDetector : MethodChannel.MethodCallHandler { + private val instances = mutableMapOf() + + companion object { + private const val START = "nlp#startLanguageIdentifier" + private const val CLOSE = "nlp#closeLanguageIdentifier" + } + + override fun onMethodCall( + call: MethodCall, + result: MethodChannel.Result, + ) { + when (call.method) { + START -> { + identifyLanguages(call, result) + } + + CLOSE -> { + closeDetector(call) + result.success(null) + } + + else -> { + result.notImplemented() + } + } + } + + private fun identifyLanguages( + call: MethodCall, + result: MethodChannel.Result, + ) { + val id = call.argument("id") ?: return + val languageIdentifier = + instances.getOrPut(id) { + val confidence = call.argument("confidence") ?: 0.0 + LanguageIdentification.getClient( + LanguageIdentificationOptions + .Builder() + .setConfidenceThreshold(confidence.toFloat()) + .build(), + ) + } + + val possibleLanguages = call.argument("possibleLanguages") ?: false + val text = call.argument("text") ?: return + + if (!possibleLanguages) { + identifyLanguage(text, languageIdentifier, result) + } else { + identifyPossibleLanguages(text, languageIdentifier, result) + } + } + + private fun identifyLanguage( + text: String, + languageIdentifier: LanguageIdentifier, + result: MethodChannel.Result, + ) { + languageIdentifier + .identifyLanguage(text) + .addOnSuccessListener { result.success(it) } + .addOnFailureListener { e -> result.error("Language Identification Error", e.toString(), null) } + } + + private fun identifyPossibleLanguages( + text: String, + languageIdentifier: LanguageIdentifier, + result: MethodChannel.Result, + ) { + languageIdentifier + .identifyPossibleLanguages(text) + .addOnSuccessListener { identifiedLanguages -> + val languageList = + identifiedLanguages.map { language -> + mapOf( + "confidence" to language.confidence, + "language" to language.languageTag, + ) + } + result.success(languageList) + }.addOnFailureListener { e -> result.error("Error identifying possible languages", e.toString(), null) } + } + + private fun closeDetector(call: MethodCall) { + val id = call.argument("id") ?: return + instances.remove(id)?.close() + } +} diff --git a/packages/google_mlkit_object_detection/android/build.gradle b/packages/google_mlkit_object_detection/android/build.gradle index 1da67cb4..23718cc5 100644 --- a/packages/google_mlkit_object_detection/android/build.gradle +++ b/packages/google_mlkit_object_detection/android/build.gradle @@ -2,6 +2,7 @@ group = "com.google_mlkit_object_detection" version = "1.0" buildscript { + ext.kotlin_version = "2.2.20" repositories { google() mavenCentral() @@ -9,6 +10,7 @@ buildscript { dependencies { classpath("com.android.tools.build:gradle:8.13.0") + classpath('org.jetbrains.kotlin:kotlin-gradle-plugin:2.2.20') } } @@ -20,6 +22,7 @@ rootProject.allprojects { } apply plugin: "com.android.library" +apply plugin: "kotlin-android" android { namespace = "com.google_mlkit_object_detection" @@ -31,6 +34,14 @@ android { targetCompatibility = JavaVersion.VERSION_11 } + kotlinOptions { + jvmTarget = '11' + } + + sourceSets { + main.java.srcDirs += "src/main/kotlin" + } + defaultConfig { minSdk = 21 } diff --git a/packages/google_mlkit_object_detection/android/gradle/wrapper/gradle-wrapper.properties b/packages/google_mlkit_object_detection/android/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 00000000..128196a7 --- /dev/null +++ b/packages/google_mlkit_object_detection/android/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-9.0-milestone-1-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/packages/google_mlkit_object_detection/android/src/main/java/com/google_mlkit_object_detection/GoogleMlKitObjectDetectionPlugin.java b/packages/google_mlkit_object_detection/android/src/main/java/com/google_mlkit_object_detection/GoogleMlKitObjectDetectionPlugin.java deleted file mode 100644 index 598bbeb4..00000000 --- a/packages/google_mlkit_object_detection/android/src/main/java/com/google_mlkit_object_detection/GoogleMlKitObjectDetectionPlugin.java +++ /dev/null @@ -1,22 +0,0 @@ -package com.google_mlkit_object_detection; - -import androidx.annotation.NonNull; - -import io.flutter.embedding.engine.plugins.FlutterPlugin; -import io.flutter.plugin.common.MethodChannel; - -public class GoogleMlKitObjectDetectionPlugin implements FlutterPlugin { - private MethodChannel channel; - private static final String channelName = "google_mlkit_object_detector"; - - @Override - public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) { - channel = new MethodChannel(flutterPluginBinding.getBinaryMessenger(), channelName); - channel.setMethodCallHandler(new ObjectDetector(flutterPluginBinding.getApplicationContext())); - } - - @Override - public void onDetachedFromEngine(@NonNull FlutterPluginBinding binding) { - channel.setMethodCallHandler(null); - } -} diff --git a/packages/google_mlkit_object_detection/android/src/main/java/com/google_mlkit_object_detection/ObjectDetector.java b/packages/google_mlkit_object_detection/android/src/main/java/com/google_mlkit_object_detection/ObjectDetector.java deleted file mode 100644 index 45cd674b..00000000 --- a/packages/google_mlkit_object_detection/android/src/main/java/com/google_mlkit_object_detection/ObjectDetector.java +++ /dev/null @@ -1,248 +0,0 @@ -package com.google_mlkit_object_detection; - -import android.content.Context; -import android.graphics.Rect; - -import androidx.annotation.NonNull; - -import com.google.mlkit.common.model.CustomRemoteModel; -import com.google.mlkit.common.model.LocalModel; -import com.google.mlkit.linkfirebase.FirebaseModelSource; -import com.google.mlkit.vision.common.InputImage; -import com.google.mlkit.vision.objects.DetectedObject; -import com.google.mlkit.vision.objects.ObjectDetection; -import com.google.mlkit.vision.objects.custom.CustomObjectDetectorOptions; -import com.google.mlkit.vision.objects.defaults.ObjectDetectorOptions; -import com.google_mlkit_commons.GenericModelManager; -import com.google_mlkit_commons.InputImageConverter; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import io.flutter.plugin.common.MethodCall; -import io.flutter.plugin.common.MethodChannel; - -public class ObjectDetector implements MethodChannel.MethodCallHandler { - private static final String START = "vision#startObjectDetector"; - private static final String CLOSE = "vision#closeObjectDetector"; - private static final String MANAGE = "vision#manageFirebaseModels"; - - private final Context context; - private final Map instances = new HashMap<>(); - private final GenericModelManager genericModelManager = new GenericModelManager(); - - public ObjectDetector(Context context) { - this.context = context; - } - - @Override - public void onMethodCall(@NonNull MethodCall call, @NonNull MethodChannel.Result result) { - String method = call.method; - switch (method) { - case START: - handleDetection(call, result); - break; - case CLOSE: - closeDetector(call); - result.success(null); - break; - case MANAGE: - manageModel(call, result); - break; - default: - result.notImplemented(); - break; - } - } - - private void handleDetection(MethodCall call, final MethodChannel.Result result) { - Map imageData = call.argument("imageData"); - InputImage inputImage = InputImageConverter.getInputImageFromData(imageData, context, result); - if (inputImage == null) return; - - String id = call.argument("id"); - com.google.mlkit.vision.objects.ObjectDetector objectDetector = instances.get(id); - if (objectDetector == null) { - Map options = call.argument("options"); - if (options == null) { - result.error("ImageLabelDetectorError", "Invalid options", null); - return; - } - - String type = (String) options.get("type"); - if (type.equals("base")) { - ObjectDetectorOptions detectorOptions = getDefaultOptions(options); - objectDetector = ObjectDetection.getClient(detectorOptions); - } else if (type.equals("local")) { - CustomObjectDetectorOptions detectorOptions = getLocalOptions(options); - objectDetector = ObjectDetection.getClient(detectorOptions); - } else if (type.equals("remote")) { - int mode = (int) options.get("mode"); - int finalMode = mode == 0 ? - CustomObjectDetectorOptions.STREAM_MODE : - CustomObjectDetectorOptions.SINGLE_IMAGE_MODE; - boolean classify = (boolean) options.get("classify"); - boolean multiple = (boolean) options.get("multiple"); - double threshold = (double) options.get("threshold"); - int maxLabels = (int) options.get("maxLabels"); - String name = (String) options.get("modelName"); - - FirebaseModelSource firebaseModelSource = new FirebaseModelSource.Builder(name) - .build(); - CustomRemoteModel remoteModel = new CustomRemoteModel.Builder(firebaseModelSource) - .build(); - - genericModelManager.isModelDownloaded( - remoteModel, - new GenericModelManager.CheckModelIsDownloadedCallback() { - @Override - public void onCheckResult(Boolean isDownloaded) { - if (!isDownloaded) { - result.error("Error Model has not been downloaded yet", "Model has not been downloaded yet", "Model has not been downloaded yet"); - return; - } - - CustomObjectDetectorOptions.Builder builder = new CustomObjectDetectorOptions.Builder(remoteModel) - .setDetectorMode(finalMode) - .setMaxPerObjectLabelCount(maxLabels) - .setClassificationConfidenceThreshold((float) threshold); - if (classify) builder.enableClassification(); - if (multiple) builder.enableMultipleObjects(); - - CustomObjectDetectorOptions customObjectDetectorOptions = builder.build(); - - startObjectDetection( - ObjectDetection.getClient(customObjectDetectorOptions), - inputImage, - result - ); - } - - @Override - public void onError(Exception e) { - result.error("Model download check failed", e.getMessage(), e); - } - } - ); - - return; - } else { - String error = "Invalid model type: " + type; - result.error(type, error, error); - return; - } - instances.put(id, objectDetector); - } - - startObjectDetection(objectDetector, inputImage, result); - } - - private void startObjectDetection( - com.google.mlkit.vision.objects.ObjectDetector objectDetector, - InputImage inputImage, - MethodChannel.Result result - ) { - objectDetector.process(inputImage).addOnSuccessListener(detectedObjects -> { - List> objects = new ArrayList<>(); - for (DetectedObject detectedObject : detectedObjects) { - Map objectMap = new HashMap<>(); - addData(objectMap, - detectedObject.getTrackingId(), - detectedObject.getBoundingBox(), - detectedObject.getLabels()); - objects.add(objectMap); - } - result.success(objects); - }).addOnFailureListener(e -> { - e.printStackTrace(); - result.error("ObjectDetectionError", e.toString(), null); - }); - } - - private ObjectDetectorOptions getDefaultOptions(Map options) { - int mode = (int) options.get("mode"); - mode = mode == 0 ? - ObjectDetectorOptions.STREAM_MODE : - ObjectDetectorOptions.SINGLE_IMAGE_MODE; - boolean classify = (boolean) options.get("classify"); - boolean multiple = (boolean) options.get("multiple"); - - ObjectDetectorOptions.Builder builder = new ObjectDetectorOptions.Builder() - .setDetectorMode(mode); - if (classify) builder.enableClassification(); - if (multiple) builder.enableMultipleObjects(); - return builder.build(); - } - - private CustomObjectDetectorOptions getLocalOptions(Map options) { - int mode = (int) options.get("mode"); - mode = mode == 0 ? - CustomObjectDetectorOptions.STREAM_MODE : - CustomObjectDetectorOptions.SINGLE_IMAGE_MODE; - boolean classify = (boolean) options.get("classify"); - boolean multiple = (boolean) options.get("multiple"); - double threshold = (double) options.get("threshold"); - int maxLabels = (int) options.get("maxLabels"); - String path = (String) options.get("path"); - - LocalModel localModel = new LocalModel.Builder() - .setAbsoluteFilePath(path) - .build(); - - CustomObjectDetectorOptions.Builder builder = new CustomObjectDetectorOptions.Builder(localModel); - builder.setDetectorMode(mode); - if (classify) builder.enableClassification(); - if (multiple) builder.enableMultipleObjects(); - builder.setMaxPerObjectLabelCount(maxLabels); - builder.setClassificationConfidenceThreshold((float) threshold); - return builder.build(); - } - - private void addData(Map addTo, - Integer trackingId, - Rect rect, - List labelList) { - List> labels = new ArrayList<>(); - addLabels(labels, labelList); - addTo.put("rect", getBoundingPoints(rect)); - addTo.put("labels", labels); - addTo.put("trackingId", trackingId); - } - - private Map getBoundingPoints(Rect rect) { - Map frame = new HashMap<>(); - frame.put("left", rect.left); - frame.put("top", rect.top); - frame.put("right", rect.right); - frame.put("bottom", rect.bottom); - return frame; - } - - private void addLabels(List> labels, List labelList) { - for (DetectedObject.Label label : labelList) { - Map labelData = new HashMap<>(); - labelData.put("index", label.getIndex()); - labelData.put("text", label.getText()); - labelData.put("confidence", (double) label.getConfidence()); - labels.add(labelData); - } - } - - private void closeDetector(MethodCall call) { - String id = call.argument("id"); - com.google.mlkit.vision.objects.ObjectDetector objectDetector = instances.get(id); - if (objectDetector == null) return; - objectDetector.close(); - instances.remove(id); - } - - private void manageModel(MethodCall call, final MethodChannel.Result result) { - FirebaseModelSource firebaseModelSource = new FirebaseModelSource.Builder(call.argument("model")) - .build(); - CustomRemoteModel model = new CustomRemoteModel.Builder(firebaseModelSource) - .build(); - genericModelManager.manageModel(model, call, result); - } -} diff --git a/packages/google_mlkit_object_detection/android/src/main/kotlin/com/google_mlkit_object_detection/GoogleMlKitObjectDetectionPlugin.kt b/packages/google_mlkit_object_detection/android/src/main/kotlin/com/google_mlkit_object_detection/GoogleMlKitObjectDetectionPlugin.kt new file mode 100644 index 00000000..0bbe9390 --- /dev/null +++ b/packages/google_mlkit_object_detection/android/src/main/kotlin/com/google_mlkit_object_detection/GoogleMlKitObjectDetectionPlugin.kt @@ -0,0 +1,21 @@ +package com.google_mlkit_object_detection + +import io.flutter.embedding.engine.plugins.FlutterPlugin +import io.flutter.plugin.common.MethodChannel + +class GoogleMlKitObjectDetectionPlugin : FlutterPlugin { + private lateinit var channel: MethodChannel + + companion object { + private const val CHANNEL_NAME = "google_mlkit_object_detector" + } + + override fun onAttachedToEngine(flutterPluginBinding: FlutterPlugin.FlutterPluginBinding) { + channel = MethodChannel(flutterPluginBinding.binaryMessenger, CHANNEL_NAME) + channel.setMethodCallHandler(ObjectDetector(flutterPluginBinding.applicationContext)) + } + + override fun onDetachedFromEngine(binding: FlutterPlugin.FlutterPluginBinding) { + channel.setMethodCallHandler(null) + } +} diff --git a/packages/google_mlkit_object_detection/android/src/main/kotlin/com/google_mlkit_object_detection/ObjectDetector.kt b/packages/google_mlkit_object_detection/android/src/main/kotlin/com/google_mlkit_object_detection/ObjectDetector.kt new file mode 100644 index 00000000..224463dd --- /dev/null +++ b/packages/google_mlkit_object_detection/android/src/main/kotlin/com/google_mlkit_object_detection/ObjectDetector.kt @@ -0,0 +1,260 @@ +package com.google_mlkit_object_detection + +import android.content.Context +import android.graphics.Rect +import com.google.mlkit.common.model.CustomRemoteModel +import com.google.mlkit.common.model.LocalModel +import com.google.mlkit.linkfirebase.FirebaseModelSource +import com.google.mlkit.vision.common.InputImage +import com.google.mlkit.vision.objects.DetectedObject +import com.google.mlkit.vision.objects.ObjectDetection +import com.google.mlkit.vision.objects.custom.CustomObjectDetectorOptions +import com.google.mlkit.vision.objects.defaults.ObjectDetectorOptions +import com.google_mlkit_commons.GenericModelManager +import com.google_mlkit_commons.InputImageConverter +import io.flutter.plugin.common.MethodCall +import io.flutter.plugin.common.MethodChannel + +class ObjectDetector( + private val context: Context, +) : MethodChannel.MethodCallHandler { + private val instances = mutableMapOf() + private val genericModelManager = GenericModelManager() + + companion object { + private const val START = "vision#startObjectDetector" + private const val CLOSE = "vision#closeObjectDetector" + private const val MANAGE = "vision#manageFirebaseModels" + } + + override fun onMethodCall( + call: MethodCall, + result: MethodChannel.Result, + ) { + when (call.method) { + START -> { + handleDetection(call, result) + } + + CLOSE -> { + closeDetector(call) + result.success(null) + } + + MANAGE -> { + manageModel(call, result) + } + + else -> { + result.notImplemented() + } + } + } + + private fun handleDetection( + call: MethodCall, + result: MethodChannel.Result, + ) { + val imageData = + call.argument>("imageData") ?: run { + result.error("ObjectDetectorError", "imageData is null", null) + return + } + val inputImage = InputImageConverter.getInputImageFromData(imageData, context, result) ?: return + + val id = call.argument("id") ?: return + var objectDetector = instances[id] + + if (objectDetector == null) { + val options = + call.argument>("options") ?: run { + result.error("ImageLabelDetectorError", "Invalid options", null) + return + } + + when (val type = options["type"] as? String) { + "base" -> { + objectDetector = ObjectDetection.getClient(getDefaultOptions(options)) + } + + "local" -> { + objectDetector = ObjectDetection.getClient(getLocalOptions(options)) + } + + "remote" -> { + val mode = options["mode"] as Int + val finalMode = + if (mode == 0) { + CustomObjectDetectorOptions.STREAM_MODE + } else { + CustomObjectDetectorOptions.SINGLE_IMAGE_MODE + } + val classify = options["classify"] as Boolean + val multiple = options["multiple"] as Boolean + val threshold = options["threshold"] as Double + val maxLabels = options["maxLabels"] as Int + val name = options["modelName"] as String + + val firebaseModelSource = FirebaseModelSource.Builder(name).build() + val remoteModel = CustomRemoteModel.Builder(firebaseModelSource).build() + + genericModelManager.isModelDownloaded( + remoteModel, + object : GenericModelManager.CheckModelIsDownloadedCallback { + override fun onCheckResult(isDownloaded: Boolean?) { + if (isDownloaded != true) { + result.error( + "Error Model has not been downloaded yet", + "Model has not been downloaded yet", + "Model has not been downloaded yet", + ) + return + } + + val builder = + CustomObjectDetectorOptions + .Builder(remoteModel) + .setDetectorMode(finalMode) + .setMaxPerObjectLabelCount(maxLabels) + .setClassificationConfidenceThreshold(threshold.toFloat()) + if (classify) builder.enableClassification() + if (multiple) builder.enableMultipleObjects() + + startObjectDetection( + ObjectDetection.getClient(builder.build()), + inputImage, + result, + ) + } + + override fun onError(e: Exception) { + result.error("Model download check failed", e.message, e) + } + }, + ) + return + } + + else -> { + val error = "Invalid model type: $type" + result.error(type ?: "unknown", error, error) + return + } + } + instances[id] = objectDetector!! + } + + startObjectDetection(objectDetector, inputImage, result) + } + + private fun startObjectDetection( + objectDetector: com.google.mlkit.vision.objects.ObjectDetector, + inputImage: InputImage, + result: MethodChannel.Result, + ) { + objectDetector + .process(inputImage) + .addOnSuccessListener { detectedObjects -> + val objects = + detectedObjects.map { detectedObject -> + mutableMapOf().apply { + addData(this, detectedObject.trackingId, detectedObject.boundingBox, detectedObject.labels) + } + } + result.success(objects) + }.addOnFailureListener { e -> + e.printStackTrace() + result.error("ObjectDetectionError", e.toString(), null) + } + } + + private fun getDefaultOptions(options: Map): ObjectDetectorOptions { + val mode = + if (options["mode"] as Int == 0) { + ObjectDetectorOptions.STREAM_MODE + } else { + ObjectDetectorOptions.SINGLE_IMAGE_MODE + } + val classify = options["classify"] as Boolean + val multiple = options["multiple"] as Boolean + + return ObjectDetectorOptions + .Builder() + .setDetectorMode(mode) + .apply { + if (classify) enableClassification() + if (multiple) enableMultipleObjects() + }.build() + } + + private fun getLocalOptions(options: Map): CustomObjectDetectorOptions { + val mode = + if (options["mode"] as Int == 0) { + CustomObjectDetectorOptions.STREAM_MODE + } else { + CustomObjectDetectorOptions.SINGLE_IMAGE_MODE + } + val classify = options["classify"] as Boolean + val multiple = options["multiple"] as Boolean + val threshold = options["threshold"] as Double + val maxLabels = options["maxLabels"] as Int + val path = options["path"] as String + + val localModel = LocalModel.Builder().setAbsoluteFilePath(path).build() + + return CustomObjectDetectorOptions + .Builder(localModel) + .setDetectorMode(mode) + .apply { + if (classify) enableClassification() + if (multiple) enableMultipleObjects() + }.setMaxPerObjectLabelCount(maxLabels) + .setClassificationConfidenceThreshold(threshold.toFloat()) + .build() + } + + private fun addData( + addTo: MutableMap, + trackingId: Int?, + rect: Rect, + labelList: List, + ) { + addTo["rect"] = getBoundingPoints(rect) + addTo["labels"] = + labelList.map { label -> + mapOf( + "index" to label.index, + "text" to label.text, + "confidence" to label.confidence.toDouble(), + ) + } + addTo["trackingId"] = trackingId + } + + private fun getBoundingPoints(rect: Rect) = + mapOf( + "left" to rect.left, + "top" to rect.top, + "right" to rect.right, + "bottom" to rect.bottom, + ) + + private fun closeDetector(call: MethodCall) { + val id = call.argument("id") ?: return + instances.remove(id)?.close() + } + + private fun manageModel( + call: MethodCall, + result: MethodChannel.Result, + ) { + val modelName = + call.argument("model") ?: run { + result.error("ObjectDetectorError", "Model name is null", null) + return + } + val firebaseModelSource = FirebaseModelSource.Builder(modelName).build() + val model = CustomRemoteModel.Builder(firebaseModelSource).build() + genericModelManager.manageModel(model, call, result) + } +} diff --git a/packages/google_mlkit_pose_detection/android/build.gradle b/packages/google_mlkit_pose_detection/android/build.gradle index b96a2955..fab1b8c8 100644 --- a/packages/google_mlkit_pose_detection/android/build.gradle +++ b/packages/google_mlkit_pose_detection/android/build.gradle @@ -2,6 +2,7 @@ group = "com.google_mlkit_pose_detection" version = "1.0" buildscript { + ext.kotlin_version = "2.2.20" repositories { google() mavenCentral() @@ -9,6 +10,7 @@ buildscript { dependencies { classpath("com.android.tools.build:gradle:8.13.0") + classpath('org.jetbrains.kotlin:kotlin-gradle-plugin:2.2.20') } } @@ -20,6 +22,7 @@ rootProject.allprojects { } apply plugin: "com.android.library" +apply plugin: "kotlin-android" android { namespace = "com.google_mlkit_pose_detection" @@ -31,6 +34,14 @@ android { targetCompatibility = JavaVersion.VERSION_11 } + kotlinOptions { + jvmTarget = '11' + } + + sourceSets { + main.java.srcDirs += "src/main/kotlin" + } + defaultConfig { minSdk = 21 } diff --git a/packages/google_mlkit_pose_detection/android/gradle/wrapper/gradle-wrapper.properties b/packages/google_mlkit_pose_detection/android/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 00000000..128196a7 --- /dev/null +++ b/packages/google_mlkit_pose_detection/android/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-9.0-milestone-1-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/packages/google_mlkit_pose_detection/android/src/main/java/com/google_mlkit_pose_detection/GoogleMlKitPoseDetectionPlugin.java b/packages/google_mlkit_pose_detection/android/src/main/java/com/google_mlkit_pose_detection/GoogleMlKitPoseDetectionPlugin.java deleted file mode 100644 index dc9b6240..00000000 --- a/packages/google_mlkit_pose_detection/android/src/main/java/com/google_mlkit_pose_detection/GoogleMlKitPoseDetectionPlugin.java +++ /dev/null @@ -1,22 +0,0 @@ -package com.google_mlkit_pose_detection; - -import androidx.annotation.NonNull; - -import io.flutter.embedding.engine.plugins.FlutterPlugin; -import io.flutter.plugin.common.MethodChannel; - -public class GoogleMlKitPoseDetectionPlugin implements FlutterPlugin { - private MethodChannel channel; - private static final String channelName = "google_mlkit_pose_detector"; - - @Override - public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) { - channel = new MethodChannel(flutterPluginBinding.getBinaryMessenger(), channelName); - channel.setMethodCallHandler(new PoseDetector(flutterPluginBinding.getApplicationContext())); - } - - @Override - public void onDetachedFromEngine(@NonNull FlutterPluginBinding binding) { - channel.setMethodCallHandler(null); - } -} diff --git a/packages/google_mlkit_pose_detection/android/src/main/java/com/google_mlkit_pose_detection/PoseDetector.java b/packages/google_mlkit_pose_detection/android/src/main/java/com/google_mlkit_pose_detection/PoseDetector.java deleted file mode 100644 index 63e63f7f..00000000 --- a/packages/google_mlkit_pose_detection/android/src/main/java/com/google_mlkit_pose_detection/PoseDetector.java +++ /dev/null @@ -1,115 +0,0 @@ -package com.google_mlkit_pose_detection; - -import android.content.Context; - -import androidx.annotation.NonNull; - -import com.google.mlkit.vision.common.InputImage; -import com.google.mlkit.vision.pose.PoseDetection; -import com.google.mlkit.vision.pose.PoseLandmark; -import com.google.mlkit.vision.pose.accurate.AccuratePoseDetectorOptions; -import com.google.mlkit.vision.pose.defaults.PoseDetectorOptions; -import com.google_mlkit_commons.InputImageConverter; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import io.flutter.plugin.common.MethodCall; -import io.flutter.plugin.common.MethodChannel; - -public class PoseDetector implements MethodChannel.MethodCallHandler { - private static final String START = "vision#startPoseDetector"; - private static final String CLOSE = "vision#closePoseDetector"; - - private final Context context; - private final Map instances = new HashMap<>(); - - public PoseDetector(Context context) { - this.context = context; - } - - @Override - public void onMethodCall(@NonNull MethodCall call, @NonNull MethodChannel.Result result) { - String method = call.method; - switch (method) { - case START: - handleDetection(call, result); - break; - case CLOSE: - closeDetector(call); - result.success(null); - break; - default: - result.notImplemented(); - break; - } - } - - private void handleDetection(MethodCall call, final MethodChannel.Result result) { - Map imageData = (Map) call.argument("imageData"); - InputImage inputImage = InputImageConverter.getInputImageFromData(imageData, context, result); - if (inputImage == null) return; - - String id = call.argument("id"); - com.google.mlkit.vision.pose.PoseDetector poseDetector = instances.get(id); - if (poseDetector == null) { - Map options = call.argument("options"); - if (options == null) { - result.error("PoseDetectorError", "Invalid options", null); - return; - } - - String mode = (String) options.get("mode"); - int detectorMode = PoseDetectorOptions.STREAM_MODE; - if (mode.equals("single")) { - detectorMode = PoseDetectorOptions.SINGLE_IMAGE_MODE; - } - - String model = (String) options.get("model"); - if (model.equals("base")) { - PoseDetectorOptions detectorOptions = new PoseDetectorOptions.Builder() - .setDetectorMode(detectorMode) - .build(); - poseDetector = PoseDetection.getClient(detectorOptions); - } else { - AccuratePoseDetectorOptions detectorOptions = new AccuratePoseDetectorOptions.Builder() - .setDetectorMode(detectorMode) - .build(); - poseDetector = PoseDetection.getClient(detectorOptions); - } - instances.put(id, poseDetector); - } - - poseDetector.process(inputImage) - .addOnSuccessListener( - pose -> { - List>> array = new ArrayList<>(); - if (!pose.getAllPoseLandmarks().isEmpty()) { - List> landmarks = new ArrayList<>(); - for (PoseLandmark poseLandmark : pose.getAllPoseLandmarks()) { - Map landmarkMap = new HashMap<>(); - landmarkMap.put("type", poseLandmark.getLandmarkType()); - landmarkMap.put("x", poseLandmark.getPosition3D().getX()); - landmarkMap.put("y", poseLandmark.getPosition3D().getY()); - landmarkMap.put("z", poseLandmark.getPosition3D().getZ()); - landmarkMap.put("likelihood", poseLandmark.getInFrameLikelihood()); - landmarks.add(landmarkMap); - } - array.add(landmarks); - } - result.success(array); - }) - .addOnFailureListener( - e -> result.error("PoseDetectorError", e.toString(), null)); - } - - private void closeDetector(MethodCall call) { - String id = call.argument("id"); - com.google.mlkit.vision.pose.PoseDetector poseDetector = instances.get(id); - if (poseDetector == null) return; - poseDetector.close(); - instances.remove(id); - } -} diff --git a/packages/google_mlkit_pose_detection/android/src/main/kotlin/com/google_mlkit_pose_detection/GoogleMlKitPoseDetectionPlugin.kt b/packages/google_mlkit_pose_detection/android/src/main/kotlin/com/google_mlkit_pose_detection/GoogleMlKitPoseDetectionPlugin.kt new file mode 100644 index 00000000..c838126f --- /dev/null +++ b/packages/google_mlkit_pose_detection/android/src/main/kotlin/com/google_mlkit_pose_detection/GoogleMlKitPoseDetectionPlugin.kt @@ -0,0 +1,21 @@ +package com.google_mlkit_pose_detection + +import io.flutter.embedding.engine.plugins.FlutterPlugin +import io.flutter.plugin.common.MethodChannel + +class GoogleMlKitPoseDetectionPlugin : FlutterPlugin { + private lateinit var channel: MethodChannel + + companion object { + private const val CHANNEL_NAME = "google_mlkit_pose_detector" + } + + override fun onAttachedToEngine(binding: FlutterPlugin.FlutterPluginBinding) { + channel = MethodChannel(binding.binaryMessenger, CHANNEL_NAME) + channel.setMethodCallHandler(PoseDetector(binding.applicationContext)) + } + + override fun onDetachedFromEngine(binding: FlutterPlugin.FlutterPluginBinding) { + channel.setMethodCallHandler(null) + } +} diff --git a/packages/google_mlkit_pose_detection/android/src/main/kotlin/com/google_mlkit_pose_detection/PoseDetector.kt b/packages/google_mlkit_pose_detection/android/src/main/kotlin/com/google_mlkit_pose_detection/PoseDetector.kt new file mode 100644 index 00000000..7539bc06 --- /dev/null +++ b/packages/google_mlkit_pose_detection/android/src/main/kotlin/com/google_mlkit_pose_detection/PoseDetector.kt @@ -0,0 +1,117 @@ +package com.google_mlkit_pose_detection + +import android.content.Context +import com.google.mlkit.vision.pose.PoseDetection +import com.google.mlkit.vision.pose.PoseLandmark +import com.google.mlkit.vision.pose.accurate.AccuratePoseDetectorOptions +import com.google.mlkit.vision.pose.defaults.PoseDetectorOptions +import com.google_mlkit_commons.InputImageConverter +import io.flutter.plugin.common.MethodCall +import io.flutter.plugin.common.MethodChannel + +class PoseDetector( + private val context: Context, +) : MethodChannel.MethodCallHandler { + private val instances = mutableMapOf() + + companion object { + private const val START = "vision#startPoseDetector" + private const val CLOSE = "vision#closePoseDetector" + } + + override fun onMethodCall( + call: MethodCall, + result: MethodChannel.Result, + ) { + when (call.method) { + START -> { + handleDetection(call, result) + } + + CLOSE -> { + closeDetector(call) + result.success(null) + } + + else -> { + result.notImplemented() + } + } + } + + private fun handleDetection( + call: MethodCall, + result: MethodChannel.Result, + ) { + val imageData = + call.argument>("imageData") ?: run { + result.error("PoseDetectorError", "imageData is null", null) + return + } + + val inputImage = InputImageConverter.getInputImageFromData(imageData, context, result) ?: return + + val id = call.argument("id") ?: return + var poseDetector = instances[id] + + if (poseDetector == null) { + val options = + call.argument>("options") ?: run { + result.error("PoseDetectorError", "Invalid options", null) + return + } + + val detectorMode = + if (options["mode"] as? String == "single") { + PoseDetectorOptions.SINGLE_IMAGE_MODE + } else { + PoseDetectorOptions.STREAM_MODE + } + + poseDetector = + if (options["model"] as? String == "base") { + PoseDetection.getClient( + PoseDetectorOptions + .Builder() + .setDetectorMode(detectorMode) + .build(), + ) + } else { + PoseDetection.getClient( + AccuratePoseDetectorOptions + .Builder() + .setDetectorMode(detectorMode) + .build(), + ) + } + instances[id] = poseDetector + } + + poseDetector + .process(inputImage) + .addOnSuccessListener { pose -> + val array = mutableListOf>>() + if (pose.allPoseLandmarks.isNotEmpty()) { + val landmarks = + pose.allPoseLandmarks.map { poseLandmark -> + mapOf( + "type" to poseLandmark.landmarkType, + "x" to poseLandmark.position3D.x, + "y" to poseLandmark.position3D.y, + "z" to poseLandmark.position3D.z, + "likelihood" to poseLandmark.inFrameLikelihood, + ) + } + array.add(landmarks) + } + result.success(array) + }.addOnFailureListener { e -> + result.error("PoseDetectorError", e.toString(), null) + } + } + + private fun closeDetector(call: MethodCall) { + val id = call.argument("id") ?: return + instances.remove(id)?.close() + } +} diff --git a/packages/google_mlkit_selfie_segmentation/android/build.gradle b/packages/google_mlkit_selfie_segmentation/android/build.gradle index 919212f1..cd4959b5 100644 --- a/packages/google_mlkit_selfie_segmentation/android/build.gradle +++ b/packages/google_mlkit_selfie_segmentation/android/build.gradle @@ -2,6 +2,7 @@ group = "com.google_mlkit_selfie_segmentation" version = "1.0" buildscript { + ext.kotlin_version = "2.2.20" repositories { google() mavenCentral() @@ -9,6 +10,7 @@ buildscript { dependencies { classpath("com.android.tools.build:gradle:8.13.0") + classpath('org.jetbrains.kotlin:kotlin-gradle-plugin:2.2.20') } } @@ -20,6 +22,7 @@ rootProject.allprojects { } apply plugin: "com.android.library" +apply plugin: "kotlin-android" android { namespace = "com.google_mlkit_selfie_segmentation" @@ -31,6 +34,14 @@ android { targetCompatibility = JavaVersion.VERSION_11 } + kotlinOptions { + jvmTarget = '11' + } + + sourceSets { + main.java.srcDirs += "src/main/kotlin" + } + defaultConfig { minSdk = 21 } diff --git a/packages/google_mlkit_selfie_segmentation/android/gradle/wrapper/gradle-wrapper.properties b/packages/google_mlkit_selfie_segmentation/android/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 00000000..128196a7 --- /dev/null +++ b/packages/google_mlkit_selfie_segmentation/android/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-9.0-milestone-1-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/packages/google_mlkit_selfie_segmentation/android/src/main/java/com/google_mlkit_selfie_segmentation/GoogleMlKitSelfieSegmentationPlugin.java b/packages/google_mlkit_selfie_segmentation/android/src/main/java/com/google_mlkit_selfie_segmentation/GoogleMlKitSelfieSegmentationPlugin.java deleted file mode 100644 index e623a98f..00000000 --- a/packages/google_mlkit_selfie_segmentation/android/src/main/java/com/google_mlkit_selfie_segmentation/GoogleMlKitSelfieSegmentationPlugin.java +++ /dev/null @@ -1,22 +0,0 @@ -package com.google_mlkit_selfie_segmentation; - -import androidx.annotation.NonNull; - -import io.flutter.embedding.engine.plugins.FlutterPlugin; -import io.flutter.plugin.common.MethodChannel; - -public class GoogleMlKitSelfieSegmentationPlugin implements FlutterPlugin { - private MethodChannel channel; - private static final String channelName = "google_mlkit_selfie_segmenter"; - - @Override - public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) { - channel = new MethodChannel(flutterPluginBinding.getBinaryMessenger(), channelName); - channel.setMethodCallHandler(new SelfieSegmenter(flutterPluginBinding.getApplicationContext())); - } - - @Override - public void onDetachedFromEngine(@NonNull FlutterPluginBinding binding) { - channel.setMethodCallHandler(null); - } -} diff --git a/packages/google_mlkit_selfie_segmentation/android/src/main/java/com/google_mlkit_selfie_segmentation/SelfieSegmenter.java b/packages/google_mlkit_selfie_segmentation/android/src/main/java/com/google_mlkit_selfie_segmentation/SelfieSegmenter.java deleted file mode 100644 index 478a4100..00000000 --- a/packages/google_mlkit_selfie_segmentation/android/src/main/java/com/google_mlkit_selfie_segmentation/SelfieSegmenter.java +++ /dev/null @@ -1,115 +0,0 @@ -package com.google_mlkit_selfie_segmentation; - -import android.content.Context; - -import androidx.annotation.NonNull; - -import com.google.mlkit.vision.common.InputImage; -import com.google.mlkit.vision.segmentation.Segmentation; -import com.google.mlkit.vision.segmentation.Segmenter; -import com.google.mlkit.vision.segmentation.selfie.SelfieSegmenterOptions; -import com.google_mlkit_commons.InputImageConverter; - -import java.nio.ByteBuffer; -import java.util.HashMap; -import java.util.Map; - -import io.flutter.plugin.common.MethodCall; -import io.flutter.plugin.common.MethodChannel; - -public class SelfieSegmenter implements MethodChannel.MethodCallHandler { - private static final String START = "vision#startSelfieSegmenter"; - private static final String CLOSE = "vision#closeSelfieSegmenter"; - - private final Context context; - private final Map instances = new HashMap<>(); - - public SelfieSegmenter(Context context) { - this.context = context; - } - - @Override - public void onMethodCall(@NonNull MethodCall call, @NonNull MethodChannel.Result result) { - String method = call.method; - switch (method) { - case START: - handleDetection(call, result); - break; - case CLOSE: - closeDetector(call); - result.success(null); - break; - default: - result.notImplemented(); - break; - } - } - - private Segmenter initialize(MethodCall call) { - Boolean isStream = call.argument("isStream"); - Boolean enableRawSizeMask = call.argument("enableRawSizeMask"); - - SelfieSegmenterOptions.Builder builder = new SelfieSegmenterOptions.Builder(); - - builder.setDetectorMode(isStream - ? SelfieSegmenterOptions.STREAM_MODE - : SelfieSegmenterOptions.SINGLE_IMAGE_MODE); - - if (enableRawSizeMask) { - builder.enableRawSizeMask(); - } - - SelfieSegmenterOptions options = builder.build(); - return Segmentation.getClient(options); - } - - private void handleDetection(MethodCall call, final MethodChannel.Result result) { - Map imageData = (Map) call.argument("imageData"); - InputImage inputImage = InputImageConverter.getInputImageFromData(imageData, context, result); - if (inputImage == null) return; - - String id = call.argument("id"); - Segmenter segmenter = instances.get(id); - if (segmenter == null) { - segmenter = initialize(call); - instances.put(id, segmenter); - } - - segmenter.process(inputImage) - .addOnSuccessListener( - segmentationMask -> { - Map map = new HashMap<>(); - ByteBuffer mask = segmentationMask.getBuffer(); - int maskWidth = segmentationMask.getWidth(); - int maskHeight = segmentationMask.getHeight(); - - map.put("width", maskWidth); - map.put("height", maskHeight); - - final float[] confidences = new float[maskWidth * maskHeight]; -// mask.asFloatBuffer().get(confidences, 0, confidences.length); - - for (int y = 0; y < maskHeight; y++) { - for (int x = 0; x < maskWidth; x++) { - // Gets the confidence of the (x,y) pixel in the mask being in the foreground. - // float foregroundConfidence = mask.getFloat(); - confidences[y * maskWidth + x] = mask.getFloat(); - } - } - - map.put("confidences", confidences); - - result.success(map); - }) - .addOnFailureListener( - e -> result.error("Selfie segmentation failed!", e.getMessage(), e)); - } - - private void closeDetector(MethodCall call) { - String id = call.argument("id"); - Segmenter segmenter = instances.get(id); - if (segmenter == null) return; - segmenter.close(); - instances.remove(id); - } -} diff --git a/packages/google_mlkit_selfie_segmentation/android/src/main/kotlin/com/google_mlkit_selfie_segmentation/GoogleMlKitSelfieSegmentationPlugin.kt b/packages/google_mlkit_selfie_segmentation/android/src/main/kotlin/com/google_mlkit_selfie_segmentation/GoogleMlKitSelfieSegmentationPlugin.kt new file mode 100644 index 00000000..ba054810 --- /dev/null +++ b/packages/google_mlkit_selfie_segmentation/android/src/main/kotlin/com/google_mlkit_selfie_segmentation/GoogleMlKitSelfieSegmentationPlugin.kt @@ -0,0 +1,21 @@ +package com.google_mlkit_selfie_segmentation + +import io.flutter.embedding.engine.plugins.FlutterPlugin +import io.flutter.plugin.common.MethodChannel + +class GoogleMlKitSelfieSegmentationPlugin : FlutterPlugin { + private lateinit var channel: MethodChannel + + companion object { + private const val CHANNEL_NAME = "google_mlkit_selfie_segmenter" + } + + override fun onAttachedToEngine(binding: FlutterPlugin.FlutterPluginBinding) { + channel = MethodChannel(binding.binaryMessenger, CHANNEL_NAME) + channel.setMethodCallHandler(SelfieSegmenter(binding.applicationContext)) + } + + override fun onDetachedFromEngine(binding: FlutterPlugin.FlutterPluginBinding) { + channel.setMethodCallHandler(null) + } +} diff --git a/packages/google_mlkit_selfie_segmentation/android/src/main/kotlin/com/google_mlkit_selfie_segmentation/SelfieSegmenter.kt b/packages/google_mlkit_selfie_segmentation/android/src/main/kotlin/com/google_mlkit_selfie_segmentation/SelfieSegmenter.kt new file mode 100644 index 00000000..87275033 --- /dev/null +++ b/packages/google_mlkit_selfie_segmentation/android/src/main/kotlin/com/google_mlkit_selfie_segmentation/SelfieSegmenter.kt @@ -0,0 +1,104 @@ +package com.google_mlkit_selfie_segmentation + +import android.content.Context +import com.google.mlkit.vision.segmentation.Segmentation +import com.google.mlkit.vision.segmentation.Segmenter +import com.google.mlkit.vision.segmentation.selfie.SelfieSegmenterOptions +import com.google_mlkit_commons.InputImageConverter +import io.flutter.plugin.common.MethodCall +import io.flutter.plugin.common.MethodChannel + +class SelfieSegmenter( + private val context: Context, +) : MethodChannel.MethodCallHandler { + private val instances = mutableMapOf() + + companion object { + private const val START = "vision#startSelfieSegmenter" + private const val CLOSE = "vision#closeSelfieSegmenter" + } + + override fun onMethodCall( + call: MethodCall, + result: MethodChannel.Result, + ) { + when (call.method) { + START -> { + handleDetection(call, result) + } + + CLOSE -> { + closeDetector(call) + result.success(null) + } + + else -> { + result.notImplemented() + } + } + } + + private fun initialize(call: MethodCall): Segmenter { + val isStream = call.argument("isStream") ?: false + val enableRawSizeMask = call.argument("enableRawSizeMask") ?: false + + val options = + SelfieSegmenterOptions + .Builder() + .setDetectorMode( + if (isStream) { + SelfieSegmenterOptions.STREAM_MODE + } else { + SelfieSegmenterOptions.SINGLE_IMAGE_MODE + }, + ).apply { if (enableRawSizeMask) enableRawSizeMask() } + .build() + + return Segmentation.getClient(options) + } + + private fun handleDetection( + call: MethodCall, + result: MethodChannel.Result, + ) { + val imageData = + call.argument>("imageData") ?: run { + result.error("SelfieSegmenterError", "imageData is null", null) + return + } + val inputImage = InputImageConverter.getInputImageFromData(imageData, context, result) ?: return + + val id = call.argument("id") ?: return + val segmenter = instances.getOrPut(id) { initialize(call) } + + segmenter + .process(inputImage) + .addOnSuccessListener { segmentationMask -> + val mask = segmentationMask.buffer + val maskWidth = segmentationMask.width + val maskHeight = segmentationMask.height + + val confidences = + FloatArray(maskWidth * maskHeight) { i -> + val y = i / maskWidth + val x = i % maskWidth + mask.getFloat() + } + + result.success( + mapOf( + "width" to maskWidth, + "height" to maskHeight, + "confidences" to confidences, + ), + ) + }.addOnFailureListener { e -> + result.error("Selfie segmentation failed!", e.message, e) + } + } + + private fun closeDetector(call: MethodCall) { + val id = call.argument("id") ?: return + instances.remove(id)?.close() + } +} diff --git a/packages/google_mlkit_smart_reply/android/build.gradle b/packages/google_mlkit_smart_reply/android/build.gradle index 812c883d..e4d02a75 100644 --- a/packages/google_mlkit_smart_reply/android/build.gradle +++ b/packages/google_mlkit_smart_reply/android/build.gradle @@ -2,6 +2,7 @@ group = "com.google_mlkit_smart_reply" version = "1.0" buildscript { + ext.kotlin_version = "2.2.20" repositories { google() mavenCentral() @@ -9,6 +10,7 @@ buildscript { dependencies { classpath("com.android.tools.build:gradle:8.13.0") + classpath('org.jetbrains.kotlin:kotlin-gradle-plugin:2.2.20') } } @@ -20,6 +22,7 @@ rootProject.allprojects { } apply plugin: "com.android.library" +apply plugin: "kotlin-android" android { namespace = "com.google_mlkit_smart_reply" @@ -31,6 +34,14 @@ android { targetCompatibility = JavaVersion.VERSION_11 } + kotlinOptions { + jvmTarget = '11' + } + + sourceSets { + main.java.srcDirs += "src/main/kotlin" + } + defaultConfig { minSdk = 21 } diff --git a/packages/google_mlkit_smart_reply/android/gradle/wrapper/gradle-wrapper.properties b/packages/google_mlkit_smart_reply/android/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 00000000..128196a7 --- /dev/null +++ b/packages/google_mlkit_smart_reply/android/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-9.0-milestone-1-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/packages/google_mlkit_smart_reply/android/src/main/java/com/google_mlkit_smart_reply/GoogleMlKitSmartReplyPlugin.java b/packages/google_mlkit_smart_reply/android/src/main/java/com/google_mlkit_smart_reply/GoogleMlKitSmartReplyPlugin.java deleted file mode 100644 index 270591a9..00000000 --- a/packages/google_mlkit_smart_reply/android/src/main/java/com/google_mlkit_smart_reply/GoogleMlKitSmartReplyPlugin.java +++ /dev/null @@ -1,22 +0,0 @@ -package com.google_mlkit_smart_reply; - -import androidx.annotation.NonNull; - -import io.flutter.embedding.engine.plugins.FlutterPlugin; -import io.flutter.plugin.common.MethodChannel; - -public class GoogleMlKitSmartReplyPlugin implements FlutterPlugin { - private MethodChannel channel; - private static final String channelName = "google_mlkit_smart_reply"; - - @Override - public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) { - channel = new MethodChannel(flutterPluginBinding.getBinaryMessenger(), channelName); - channel.setMethodCallHandler(new SmartReply()); - } - - @Override - public void onDetachedFromEngine(@NonNull FlutterPluginBinding binding) { - channel.setMethodCallHandler(null); - } -} diff --git a/packages/google_mlkit_smart_reply/android/src/main/java/com/google_mlkit_smart_reply/SmartReply.java b/packages/google_mlkit_smart_reply/android/src/main/java/com/google_mlkit_smart_reply/SmartReply.java deleted file mode 100644 index 54eaddf7..00000000 --- a/packages/google_mlkit_smart_reply/android/src/main/java/com/google_mlkit_smart_reply/SmartReply.java +++ /dev/null @@ -1,89 +0,0 @@ -package com.google_mlkit_smart_reply; - -import androidx.annotation.NonNull; - -import com.google.mlkit.nl.smartreply.SmartReplyGenerator; -import com.google.mlkit.nl.smartreply.SmartReplySuggestion; -import com.google.mlkit.nl.smartreply.SmartReplySuggestionResult; -import com.google.mlkit.nl.smartreply.TextMessage; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import io.flutter.plugin.common.MethodCall; -import io.flutter.plugin.common.MethodChannel; - -public class SmartReply implements MethodChannel.MethodCallHandler { - private static final String START = "nlp#startSmartReply"; - private static final String CLOSE = "nlp#closeSmartReply"; - - private final Map instances = new HashMap<>(); - - @Override - public void onMethodCall(@NonNull MethodCall call, @NonNull MethodChannel.Result result) { - String methodCall = call.method; - switch (methodCall) { - case START: - suggestReply(call, result); - break; - case CLOSE: - closeDetector(call); - result.success(null); - break; - default: - result.notImplemented(); - break; - } - } - - private void suggestReply(@NonNull MethodCall call, @NonNull MethodChannel.Result result) { - List conversation = new ArrayList<>(); - List> json = call.argument("conversation"); - - for (final Map object : json) { - String message = (String) object.get("message"); - long timestamp = (long) object.get("timestamp"); - String userId = (String) object.get("userId"); - if (userId.equals("local")) { - conversation.add(TextMessage.createForLocalUser(message, - timestamp)); - } else { - conversation.add(TextMessage.createForRemoteUser(message, - timestamp, userId)); - } - } - - String id = call.argument("id"); - SmartReplyGenerator smartReplyGenerator = instances.get(id); - if (smartReplyGenerator == null) { - smartReplyGenerator = com.google.mlkit.nl.smartreply.SmartReply.getClient(); - instances.put(id, smartReplyGenerator); - } - - smartReplyGenerator.suggestReplies(conversation) - .addOnSuccessListener(smartReplySuggestionResult -> { - int status = smartReplySuggestionResult.getStatus(); - Map suggestionResult = new HashMap<>(); - suggestionResult.put("status", status); - if (status == SmartReplySuggestionResult.STATUS_SUCCESS) { - List suggestions = new ArrayList<>(); - for (SmartReplySuggestion suggestion : smartReplySuggestionResult.getSuggestions()) { - suggestions.add(suggestion.getText()); - } - suggestionResult.put("suggestions", suggestions); - } - result.success(suggestionResult); - }) - .addOnFailureListener(e -> result.error("failed suggesting", e.toString(), null)); - } - - private void closeDetector(MethodCall call) { - String id = call.argument("id"); - SmartReplyGenerator smartReplyGenerator = instances.get(id); - if (smartReplyGenerator == null) return; - smartReplyGenerator.close(); - instances.remove(id); - } -} diff --git a/packages/google_mlkit_smart_reply/android/src/main/kotlin/com/google_mlkit_smart_reply/GoogleMlKitSmartReplyPlugin.kt b/packages/google_mlkit_smart_reply/android/src/main/kotlin/com/google_mlkit_smart_reply/GoogleMlKitSmartReplyPlugin.kt new file mode 100644 index 00000000..62afbaa9 --- /dev/null +++ b/packages/google_mlkit_smart_reply/android/src/main/kotlin/com/google_mlkit_smart_reply/GoogleMlKitSmartReplyPlugin.kt @@ -0,0 +1,21 @@ +package com.google_mlkit_smart_reply + +import io.flutter.embedding.engine.plugins.FlutterPlugin +import io.flutter.plugin.common.MethodChannel + +class GoogleMlKitSmartReplyPlugin : FlutterPlugin { + private lateinit var channel: MethodChannel + + companion object { + private const val CHANNEL_NAME = "google_mlkit_smart_reply" + } + + override fun onAttachedToEngine(binding: FlutterPlugin.FlutterPluginBinding) { + channel = MethodChannel(binding.binaryMessenger, CHANNEL_NAME) + channel.setMethodCallHandler(SmartReplyHandler()) + } + + override fun onDetachedFromEngine(binding: FlutterPlugin.FlutterPluginBinding) { + channel.setMethodCallHandler(null) + } +} diff --git a/packages/google_mlkit_smart_reply/android/src/main/kotlin/com/google_mlkit_smart_reply/SmartReplyHandler.kt b/packages/google_mlkit_smart_reply/android/src/main/kotlin/com/google_mlkit_smart_reply/SmartReplyHandler.kt new file mode 100644 index 00000000..ac833c60 --- /dev/null +++ b/packages/google_mlkit_smart_reply/android/src/main/kotlin/com/google_mlkit_smart_reply/SmartReplyHandler.kt @@ -0,0 +1,77 @@ +package com.google_mlkit_smart_reply + +import com.google.mlkit.nl.smartreply.SmartReply +import com.google.mlkit.nl.smartreply.SmartReplyGenerator +import com.google.mlkit.nl.smartreply.SmartReplySuggestionResult +import com.google.mlkit.nl.smartreply.TextMessage +import io.flutter.plugin.common.MethodCall +import io.flutter.plugin.common.MethodChannel + +class SmartReplyHandler : MethodChannel.MethodCallHandler { + private val instances = mutableMapOf() + + companion object { + private const val START = "nlp#startSmartReply" + private const val CLOSE = "nlp#closeSmartReply" + } + + override fun onMethodCall( + call: MethodCall, + result: MethodChannel.Result, + ) { + when (call.method) { + START -> { + suggestReply(call, result) + } + + CLOSE -> { + closeDetector(call) + result.success(null) + } + + else -> { + result.notImplemented() + } + } + } + + private fun suggestReply( + call: MethodCall, + result: MethodChannel.Result, + ) { + val json = call.argument>>("conversation") ?: return + + val conversation = + json.map { obj -> + val message = obj["message"] as String + val timestamp = obj["timestamp"] as Long + val userId = obj["userId"] as String + if (userId == "local") { + TextMessage.createForLocalUser(message, timestamp) + } else { + TextMessage.createForRemoteUser(message, timestamp, userId) + } + } + + val id = call.argument("id") ?: return + val smartReplyGenerator = instances.getOrPut(id) { SmartReply.getClient() } + + smartReplyGenerator + .suggestReplies(conversation) + .addOnSuccessListener { smartReplySuggestionResult -> + val status = smartReplySuggestionResult.status + val suggestionResult = mutableMapOf("status" to status) + if (status == SmartReplySuggestionResult.STATUS_SUCCESS) { + suggestionResult["suggestions"] = smartReplySuggestionResult.suggestions.map { it.text } + } + result.success(suggestionResult) + }.addOnFailureListener { e -> + result.error("failed suggesting", e.toString(), null) + } + } + + private fun closeDetector(call: MethodCall) { + val id = call.argument("id") ?: return + instances.remove(id)?.close() + } +} diff --git a/packages/google_mlkit_subject_segmentation/android/build.gradle b/packages/google_mlkit_subject_segmentation/android/build.gradle index ce6791b0..dc49494e 100644 --- a/packages/google_mlkit_subject_segmentation/android/build.gradle +++ b/packages/google_mlkit_subject_segmentation/android/build.gradle @@ -2,6 +2,7 @@ group = "com.google_mlkit_subject_segmentation" version = "1.0" buildscript { + ext.kotlin_version = "2.2.20" repositories { google() mavenCentral() @@ -9,6 +10,7 @@ buildscript { dependencies { classpath("com.android.tools.build:gradle:8.13.0") + classpath('org.jetbrains.kotlin:kotlin-gradle-plugin:2.2.20') } } @@ -20,6 +22,7 @@ rootProject.allprojects { } apply plugin: "com.android.library" +apply plugin: "kotlin-android" android { namespace = "com.google_mlkit_subject_segmentation" @@ -31,6 +34,14 @@ android { targetCompatibility = JavaVersion.VERSION_11 } + kotlinOptions { + jvmTarget = '11' + } + + sourceSets { + main.java.srcDirs += "src/main/kotlin" + } + defaultConfig { minSdk = 24 } diff --git a/packages/google_mlkit_subject_segmentation/android/gradle/wrapper/gradle-wrapper.properties b/packages/google_mlkit_subject_segmentation/android/gradle/wrapper/gradle-wrapper.properties index 62f495df..128196a7 100644 --- a/packages/google_mlkit_subject_segmentation/android/gradle/wrapper/gradle-wrapper.properties +++ b/packages/google_mlkit_subject_segmentation/android/gradle/wrapper/gradle-wrapper.properties @@ -1,6 +1,6 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-8.2-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-9.0-milestone-1-bin.zip networkTimeout=10000 validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME diff --git a/packages/google_mlkit_subject_segmentation/android/src/main/java/com/google_mlkit_subject_segmentation/GoogleMlKitSubjectSegmentationPlugin.java b/packages/google_mlkit_subject_segmentation/android/src/main/java/com/google_mlkit_subject_segmentation/GoogleMlKitSubjectSegmentationPlugin.java deleted file mode 100644 index cb5c629f..00000000 --- a/packages/google_mlkit_subject_segmentation/android/src/main/java/com/google_mlkit_subject_segmentation/GoogleMlKitSubjectSegmentationPlugin.java +++ /dev/null @@ -1,22 +0,0 @@ -package com.google_mlkit_subject_segmentation; - -import androidx.annotation.NonNull; - -import io.flutter.embedding.engine.plugins.FlutterPlugin; -import io.flutter.plugin.common.MethodChannel; - -public class GoogleMlKitSubjectSegmentationPlugin implements FlutterPlugin { - private MethodChannel channel; - private static final String channelName = "google_mlkit_subject_segmentation"; - - @Override - public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) { - channel = new MethodChannel(flutterPluginBinding.getBinaryMessenger(), channelName); - channel.setMethodCallHandler(new SubjectSegmenter(flutterPluginBinding.getApplicationContext())); - } - - @Override - public void onDetachedFromEngine(@NonNull FlutterPluginBinding binding) { - channel.setMethodCallHandler(null); - } -} diff --git a/packages/google_mlkit_subject_segmentation/android/src/main/java/com/google_mlkit_subject_segmentation/SubjectSegmenter.java b/packages/google_mlkit_subject_segmentation/android/src/main/java/com/google_mlkit_subject_segmentation/SubjectSegmenter.java deleted file mode 100644 index f6ed9ed7..00000000 --- a/packages/google_mlkit_subject_segmentation/android/src/main/java/com/google_mlkit_subject_segmentation/SubjectSegmenter.java +++ /dev/null @@ -1,154 +0,0 @@ -package com.google_mlkit_subject_segmentation; - -import android.content.Context; -import android.graphics.Bitmap; - -import androidx.annotation.NonNull; - -import com.google.mlkit.vision.common.InputImage; -import com.google.mlkit.vision.segmentation.subject.Subject; -import com.google.mlkit.vision.segmentation.subject.SubjectSegmentation; -import com.google.mlkit.vision.segmentation.subject.SubjectSegmentationResult; - -import java.io.ByteArrayOutputStream; -import java.util.ArrayList; -import java.util.List; -import java.nio.FloatBuffer; -import java.util.HashMap; -import java.util.Map; - -import io.flutter.plugin.common.MethodCall; -import io.flutter.plugin.common.MethodChannel; - -import com.google.mlkit.vision.segmentation.subject.SubjectSegmenterOptions; -import com.google_mlkit_commons.InputImageConverter; - -public class SubjectSegmenter implements MethodChannel.MethodCallHandler { - private static final String START = "vision#startSubjectSegmenter"; - private static final String CLOSE = "vision#closeSubjectSegmenter"; - - private final Context context; - - private final Map instances = new HashMap<>(); - - public SubjectSegmenter(Context context) { - this.context = context; - } - - @Override - public void onMethodCall(@NonNull MethodCall call, @NonNull MethodChannel.Result result) { - String method = call.method; - switch (method) { - case START: - handleDetection(call, result); - break; - case CLOSE: - closeDetector(call); - result.success(null); - break; - default: - result.notImplemented(); - break; - } - } - - private void handleDetection(MethodCall call, MethodChannel.Result result) { - Map imageData = (Map) call.argument("imageData"); - InputImage inputImage = InputImageConverter.getInputImageFromData(imageData, context, result); - if (inputImage == null) return; - - String id = call.argument("id"); - com.google.mlkit.vision.segmentation.subject.SubjectSegmenter subjectSegmenter = getOrCreateSegmenter(id, call); - subjectSegmenter.process(inputImage).addOnSuccessListener(subjectSegmentationResult -> processResult(subjectSegmentationResult, result)).addOnFailureListener(e -> result.error("Subject segmentation failure!", e.getMessage(), e)); - } - - private com.google.mlkit.vision.segmentation.subject.SubjectSegmenter getOrCreateSegmenter(String id, MethodCall call) { - return instances.computeIfAbsent(id, k -> initialize(call)); - } - - private com.google.mlkit.vision.segmentation.subject.SubjectSegmenter initialize(MethodCall call) { - Map options = call.argument("options"); - SubjectSegmenterOptions.Builder builder = new SubjectSegmenterOptions.Builder(); - assert options != null; - configureBuilder(builder, options); - return SubjectSegmentation.getClient(builder.build()); - } - - private void configureBuilder(SubjectSegmenterOptions.Builder builder, Map options) { - if (Boolean.TRUE.equals(options.get("enableForegroundBitmap"))) { - builder.enableForegroundBitmap(); - } - if (Boolean.TRUE.equals(options.get("enableForegroundConfidenceMask"))) { - builder.enableForegroundConfidenceMask(); - } - configureMultipleSubjects(builder, (Map) options.get("enableMultiSubjectBitmap")); - } - - private void configureMultipleSubjects(SubjectSegmenterOptions.Builder builder, Map options) { - boolean enableConfidenceMask = Boolean.TRUE.equals(options.get("enableConfidenceMask")); - boolean enableSubjectBitmap = Boolean.TRUE.equals(options.get("enableSubjectBitmap")); - SubjectSegmenterOptions.SubjectResultOptions.Builder subjectResultOptionsBuilder = new SubjectSegmenterOptions.SubjectResultOptions.Builder(); - if (enableConfidenceMask) subjectResultOptionsBuilder.enableConfidenceMask(); - if (enableSubjectBitmap) subjectResultOptionsBuilder.enableSubjectBitmap(); - if (enableConfidenceMask || enableSubjectBitmap) { - builder.enableMultipleSubjects(subjectResultOptionsBuilder.build()); - } - } - - private void processResult(SubjectSegmentationResult subjectSegmentationResult, MethodChannel.Result result) { - Map resultMap = new HashMap<>(); - FloatBuffer foregroundConfidenceMask = subjectSegmentationResult.getForegroundConfidenceMask(); - if (foregroundConfidenceMask != null) { - resultMap.put("foregroundConfidenceMask", getConfidenceMask(foregroundConfidenceMask)); - } - Bitmap foregroundBitmap = subjectSegmentationResult.getForegroundBitmap(); - if (foregroundBitmap != null) { - resultMap.put("foregroundBitmap", getBitmapBytes(foregroundBitmap)); - } - List> subjectsData = new ArrayList<>(); - for (Subject subject : subjectSegmentationResult.getSubjects()) { - Map subjectData = getStringObjectMap(subject); - subjectsData.add(subjectData); - } - resultMap.put("subjects", subjectsData); - result.success(resultMap); - } - - private static float[] getConfidenceMask(FloatBuffer floatBuffer) { - float[] mask = new float[floatBuffer.remaining()]; - floatBuffer.get(mask); - return mask; - } - - private static byte[] getBitmapBytes(Bitmap bitmap) { - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - bitmap.compress(Bitmap.CompressFormat.PNG, 100, outputStream); - return outputStream.toByteArray(); - } - - @NonNull - private static Map getStringObjectMap(Subject subject) { - Map subjectData = new HashMap<>(); - subjectData.put("startX", subject.getStartX()); - subjectData.put("startY", subject.getStartY()); - subjectData.put("width", subject.getWidth()); - subjectData.put("height", subject.getHeight()); - FloatBuffer confidenceMask = subject.getConfidenceMask(); - if (confidenceMask != null) { - subjectData.put("confidenceMask", getConfidenceMask(confidenceMask)); - } - Bitmap bitmap = subject.getBitmap(); - if (bitmap != null) { - subjectData.put("bitmap", getBitmapBytes(bitmap)); - } - return subjectData; - } - - private void closeDetector(MethodCall call) { - String id = call.argument("id"); - com.google.mlkit.vision.segmentation.subject.SubjectSegmenter subjectSegmenter = instances.get(id); - if (subjectSegmenter == null) return; - subjectSegmenter.close(); - instances.remove(id); - } -} diff --git a/packages/google_mlkit_subject_segmentation/android/src/main/kotlin/com/google_mlkit_subject_segmentation/GoogleMlKitSubjectSegmentationPlugin.kt b/packages/google_mlkit_subject_segmentation/android/src/main/kotlin/com/google_mlkit_subject_segmentation/GoogleMlKitSubjectSegmentationPlugin.kt new file mode 100644 index 00000000..4620679f --- /dev/null +++ b/packages/google_mlkit_subject_segmentation/android/src/main/kotlin/com/google_mlkit_subject_segmentation/GoogleMlKitSubjectSegmentationPlugin.kt @@ -0,0 +1,21 @@ +package com.google_mlkit_subject_segmentation + +import io.flutter.embedding.engine.plugins.FlutterPlugin +import io.flutter.plugin.common.MethodChannel + +class GoogleMlKitSubjectSegmentationPlugin : FlutterPlugin { + private lateinit var channel: MethodChannel + + companion object { + private const val CHANNEL_NAME = "google_mlkit_subject_segmentation" + } + + override fun onAttachedToEngine(binding: FlutterPlugin.FlutterPluginBinding) { + channel = MethodChannel(binding.binaryMessenger, CHANNEL_NAME) + channel.setMethodCallHandler(SubjectSegmenter(binding.applicationContext)) + } + + override fun onDetachedFromEngine(binding: FlutterPlugin.FlutterPluginBinding) { + channel.setMethodCallHandler(null) + } +} diff --git a/packages/google_mlkit_subject_segmentation/android/src/main/kotlin/com/google_mlkit_subject_segmentation/SubjectSegmenter.kt b/packages/google_mlkit_subject_segmentation/android/src/main/kotlin/com/google_mlkit_subject_segmentation/SubjectSegmenter.kt new file mode 100644 index 00000000..948b7eef --- /dev/null +++ b/packages/google_mlkit_subject_segmentation/android/src/main/kotlin/com/google_mlkit_subject_segmentation/SubjectSegmenter.kt @@ -0,0 +1,144 @@ +package com.google_mlkit_subject_segmentation + +import android.content.Context +import android.graphics.Bitmap +import com.google.mlkit.vision.segmentation.subject.Subject +import com.google.mlkit.vision.segmentation.subject.SubjectSegmentation +import com.google.mlkit.vision.segmentation.subject.SubjectSegmentationResult +import com.google.mlkit.vision.segmentation.subject.SubjectSegmenterOptions +import com.google_mlkit_commons.InputImageConverter +import io.flutter.plugin.common.MethodCall +import io.flutter.plugin.common.MethodChannel +import java.io.ByteArrayOutputStream +import java.nio.FloatBuffer + +class SubjectSegmenter( + private val context: Context, +) : MethodChannel.MethodCallHandler { + private val instances = mutableMapOf() + + companion object { + private const val START = "vision#startSubjectSegmenter" + private const val CLOSE = "vision#closeSubjectSegmenter" + + private fun getConfidenceMask(floatBuffer: FloatBuffer): FloatArray { + val mask = FloatArray(floatBuffer.remaining()) + floatBuffer.get(mask) + return mask + } + + private fun getBitmapBytes(bitmap: Bitmap): ByteArray { + val outputStream = ByteArrayOutputStream() + bitmap.compress(Bitmap.CompressFormat.PNG, 100, outputStream) + return outputStream.toByteArray() + } + + private fun getSubjectMap(subject: Subject): Map = + mutableMapOf( + "startX" to subject.startX, + "startY" to subject.startY, + "width" to subject.width, + "height" to subject.height, + ).apply { + subject.confidenceMask?.let { put("confidenceMask", getConfidenceMask(it)) } + subject.bitmap?.let { put("bitmap", getBitmapBytes(it)) } + } + } + + override fun onMethodCall( + call: MethodCall, + result: MethodChannel.Result, + ) { + when (call.method) { + START -> { + handleDetection(call, result) + } + + CLOSE -> { + closeDetector(call) + result.success(null) + } + + else -> { + result.notImplemented() + } + } + } + + private fun handleDetection( + call: MethodCall, + result: MethodChannel.Result, + ) { + val imageData = + call.argument>("imageData") ?: run { + result.error("SubjectSegmenterError", "imageData is null", null) + return + } + val inputImage = InputImageConverter.getInputImageFromData(imageData, context, result) ?: return + + val id = call.argument("id") ?: return + val segmenter = instances.getOrPut(id) { initialize(call) } + + segmenter + .process(inputImage) + .addOnSuccessListener { processResult(it, result) } + .addOnFailureListener { e -> result.error("Subject segmentation failure!", e.message, e) } + } + + private fun initialize(call: MethodCall): com.google.mlkit.vision.segmentation.subject.SubjectSegmenter { + val options = call.argument>("options") ?: emptyMap() + val builder = SubjectSegmenterOptions.Builder() + configureBuilder(builder, options) + return SubjectSegmentation.getClient(builder.build()) + } + + private fun configureBuilder( + builder: SubjectSegmenterOptions.Builder, + options: Map, + ) { + if (options["enableForegroundBitmap"] == true) builder.enableForegroundBitmap() + if (options["enableForegroundConfidenceMask"] == true) builder.enableForegroundConfidenceMask() + + @Suppress("UNCHECKED_CAST") + (options["enableMultiSubjectBitmap"] as? Map)?.let { + configureMultipleSubjects(builder, it) + } + } + + private fun configureMultipleSubjects( + builder: SubjectSegmenterOptions.Builder, + options: Map, + ) { + val enableConfidenceMask = options["enableConfidenceMask"] == true + val enableSubjectBitmap = options["enableSubjectBitmap"] == true + + if (enableConfidenceMask || enableSubjectBitmap) { + val subjectOptionsBuilder = SubjectSegmenterOptions.SubjectResultOptions.Builder() + if (enableConfidenceMask) subjectOptionsBuilder.enableConfidenceMask() + if (enableSubjectBitmap) subjectOptionsBuilder.enableSubjectBitmap() + builder.enableMultipleSubjects(subjectOptionsBuilder.build()) + } + } + + private fun processResult( + segmentationResult: SubjectSegmentationResult, + result: MethodChannel.Result, + ) { + val resultMap = mutableMapOf() + + segmentationResult.foregroundConfidenceMask?.let { + resultMap["foregroundConfidenceMask"] = getConfidenceMask(it) + } + segmentationResult.foregroundBitmap?.let { + resultMap["foregroundBitmap"] = getBitmapBytes(it) + } + resultMap["subjects"] = segmentationResult.subjects.map { getSubjectMap(it) } + + result.success(resultMap) + } + + private fun closeDetector(call: MethodCall) { + val id = call.argument("id") ?: return + instances.remove(id)?.close() + } +} diff --git a/packages/google_mlkit_text_recognition/android/build.gradle b/packages/google_mlkit_text_recognition/android/build.gradle index a3920972..b2fe59d2 100644 --- a/packages/google_mlkit_text_recognition/android/build.gradle +++ b/packages/google_mlkit_text_recognition/android/build.gradle @@ -2,6 +2,7 @@ group = "com.google_mlkit_text_recognition" version = "1.0" buildscript { + ext.kotlin_version = "2.2.20" repositories { google() mavenCentral() @@ -9,6 +10,7 @@ buildscript { dependencies { classpath("com.android.tools.build:gradle:8.13.0") + classpath('org.jetbrains.kotlin:kotlin-gradle-plugin:2.2.20') } } @@ -20,6 +22,7 @@ rootProject.allprojects { } apply plugin: "com.android.library" +apply plugin: "kotlin-android" android { namespace = "com.google_mlkit_text_recognition" @@ -31,6 +34,14 @@ android { targetCompatibility = JavaVersion.VERSION_11 } + kotlinOptions { + jvmTarget = '11' + } + + sourceSets { + main.java.srcDirs += "src/main/kotlin" + } + defaultConfig { minSdk = 21 } diff --git a/packages/google_mlkit_text_recognition/android/gradle/wrapper/gradle-wrapper.properties b/packages/google_mlkit_text_recognition/android/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 00000000..128196a7 --- /dev/null +++ b/packages/google_mlkit_text_recognition/android/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-9.0-milestone-1-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/packages/google_mlkit_text_recognition/android/src/main/java/com/google_mlkit_text_recognition/GoogleMlKitTextRecognitionPlugin.java b/packages/google_mlkit_text_recognition/android/src/main/java/com/google_mlkit_text_recognition/GoogleMlKitTextRecognitionPlugin.java deleted file mode 100644 index 74e2e0ff..00000000 --- a/packages/google_mlkit_text_recognition/android/src/main/java/com/google_mlkit_text_recognition/GoogleMlKitTextRecognitionPlugin.java +++ /dev/null @@ -1,22 +0,0 @@ -package com.google_mlkit_text_recognition; - -import androidx.annotation.NonNull; - -import io.flutter.embedding.engine.plugins.FlutterPlugin; -import io.flutter.plugin.common.MethodChannel; - -public class GoogleMlKitTextRecognitionPlugin implements FlutterPlugin { - private MethodChannel channel; - private static final String channelName = "google_mlkit_text_recognizer"; - - @Override - public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) { - channel = new MethodChannel(flutterPluginBinding.getBinaryMessenger(), channelName); - channel.setMethodCallHandler(new TextRecognizer(flutterPluginBinding.getApplicationContext())); - } - - @Override - public void onDetachedFromEngine(@NonNull FlutterPluginBinding binding) { - channel.setMethodCallHandler(null); - } -} diff --git a/packages/google_mlkit_text_recognition/android/src/main/java/com/google_mlkit_text_recognition/TextRecognizer.java b/packages/google_mlkit_text_recognition/android/src/main/java/com/google_mlkit_text_recognition/TextRecognizer.java deleted file mode 100644 index c7b981bc..00000000 --- a/packages/google_mlkit_text_recognition/android/src/main/java/com/google_mlkit_text_recognition/TextRecognizer.java +++ /dev/null @@ -1,210 +0,0 @@ -package com.google_mlkit_text_recognition; - -import android.content.Context; -import android.graphics.Point; -import android.graphics.Rect; - -import androidx.annotation.NonNull; - -import com.google.mlkit.vision.common.InputImage; -import com.google.mlkit.vision.text.Text; -import com.google.mlkit.vision.text.TextRecognition; -import com.google.mlkit.vision.text.chinese.ChineseTextRecognizerOptions; -import com.google.mlkit.vision.text.devanagari.DevanagariTextRecognizerOptions; -import com.google.mlkit.vision.text.japanese.JapaneseTextRecognizerOptions; -import com.google.mlkit.vision.text.korean.KoreanTextRecognizerOptions; -import com.google.mlkit.vision.text.latin.TextRecognizerOptions; -import com.google_mlkit_commons.InputImageConverter; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import io.flutter.plugin.common.MethodCall; -import io.flutter.plugin.common.MethodChannel; - -public class TextRecognizer implements MethodChannel.MethodCallHandler { - private static final String START = "vision#startTextRecognizer"; - private static final String CLOSE = "vision#closeTextRecognizer"; - - private final Context context; - private final Map instances = new HashMap<>(); - - public TextRecognizer(Context context) { - this.context = context; - } - - @Override - public void onMethodCall(@NonNull MethodCall call, @NonNull MethodChannel.Result result) { - String method = call.method; - switch (method) { - case START: - handleDetection(call, result); - break; - case CLOSE: - closeDetector(call); - result.success(null); - break; - default: - result.notImplemented(); - break; - } - } - - private com.google.mlkit.vision.text.TextRecognizer initialize(MethodCall call) { - Integer script = call.argument("script"); - if (script == null) { - return null; - } - switch (script) { - case 0: - return TextRecognition.getClient(TextRecognizerOptions.DEFAULT_OPTIONS); - case 1: - return TextRecognition.getClient(new ChineseTextRecognizerOptions.Builder().build()); - case 2: - return TextRecognition.getClient(new DevanagariTextRecognizerOptions.Builder().build()); - case 3: - return TextRecognition.getClient(new JapaneseTextRecognizerOptions.Builder().build()); - case 4: - return TextRecognition.getClient(new KoreanTextRecognizerOptions.Builder().build()); - default: - return null; - } - } - - private void handleDetection(MethodCall call, final MethodChannel.Result result) { - Map imageData = call.argument("imageData"); - if (imageData == null) { - return; - } - InputImage inputImage = InputImageConverter.getInputImageFromData(imageData, context, result); - if (inputImage == null) return; - - String id = call.argument("id"); - com.google.mlkit.vision.text.TextRecognizer textRecognizer = instances.get(id); - if (textRecognizer == null) { - textRecognizer = initialize(call); - instances.put(id, textRecognizer); - } - if (textRecognizer == null) { - result.error("TextRecognizerError", "TextRecognizer is not initialized", null); - return; - } - textRecognizer.process(inputImage) - .addOnSuccessListener(text -> { - Map textResult = new HashMap<>(); - - textResult.put("text", text.getText()); - - List> textBlocks = new ArrayList<>(); - for (Text.TextBlock block : text.getTextBlocks()) { - Map blockData = new HashMap<>(); - - addData(blockData, - block.getText(), - block.getBoundingBox(), - block.getCornerPoints(), - block.getRecognizedLanguage(), - null, - null); - - List> textLines = new ArrayList<>(); - for (Text.Line line : block.getLines()) { - Map lineData = new HashMap<>(); - - addData(lineData, - line.getText(), - line.getBoundingBox(), - line.getCornerPoints(), - line.getRecognizedLanguage(), - line.getConfidence(), - line.getAngle()); - - List> elementsData = new ArrayList<>(); - for (Text.Element element : line.getElements()) { - Map elementData = new HashMap<>(); - - addData(elementData, - element.getText(), - element.getBoundingBox(), - element.getCornerPoints(), - element.getRecognizedLanguage(), - element.getConfidence(), - element.getAngle()); - - List> symbolsData = new ArrayList<>(); - for (Text.Symbol symbol : element.getSymbols()) { - Map symbolData = new HashMap<>(); - - addData(symbolData, - symbol.getText(), - symbol.getBoundingBox(), - symbol.getCornerPoints(), - symbol.getRecognizedLanguage(), - symbol.getConfidence(), - symbol.getAngle()); - symbolsData.add(symbolData); - } - - elementData.put("symbols", symbolsData); - elementsData.add(elementData); - } - lineData.put("elements", elementsData); - textLines.add(lineData); - } - blockData.put("lines", textLines); - textBlocks.add(blockData); - } - textResult.put("blocks", textBlocks); - result.success(textResult); - }) - .addOnFailureListener(e -> result.error("TextRecognizerError", e.toString(), null)); - } - - private void addData(Map addTo, - String text, - Rect rect, - Point[] cornerPoints, - String recognizedLanguage, - Float confidence, - Float angle - ) { - List recognizedLanguages = new ArrayList<>(); - recognizedLanguages.add(recognizedLanguage); - List> points = new ArrayList<>(); - addPoints(cornerPoints, points); - addTo.put("points", points); - addTo.put("rect", getBoundingPoints(rect)); - addTo.put("recognizedLanguages", recognizedLanguages); - addTo.put("text", text); - addTo.put("confidence", confidence); - addTo.put("angle", angle); - } - - private void addPoints(Point[] cornerPoints, List> points) { - for (Point point : cornerPoints) { - Map p = new HashMap<>(); - p.put("x", point.x); - p.put("y", point.y); - points.add(p); - } - } - - private Map getBoundingPoints(Rect rect) { - Map frame = new HashMap<>(); - frame.put("left", rect.left); - frame.put("right", rect.right); - frame.put("top", rect.top); - frame.put("bottom", rect.bottom); - return frame; - } - - private void closeDetector(MethodCall call) { - String id = call.argument("id"); - com.google.mlkit.vision.text.TextRecognizer textRecognizer = instances.get(id); - if (textRecognizer == null) return; - textRecognizer.close(); - instances.remove(id); - } -} diff --git a/packages/google_mlkit_text_recognition/android/src/main/kotlin/com/google_mlkit_text_recognition/GoogleMlKitTextRecognitionPlugin.kt b/packages/google_mlkit_text_recognition/android/src/main/kotlin/com/google_mlkit_text_recognition/GoogleMlKitTextRecognitionPlugin.kt new file mode 100644 index 00000000..df5c708b --- /dev/null +++ b/packages/google_mlkit_text_recognition/android/src/main/kotlin/com/google_mlkit_text_recognition/GoogleMlKitTextRecognitionPlugin.kt @@ -0,0 +1,21 @@ +package com.google_mlkit_text_recognition + +import io.flutter.embedding.engine.plugins.FlutterPlugin +import io.flutter.plugin.common.MethodChannel + +class GoogleMlKitTextRecognitionPlugin : FlutterPlugin { + private lateinit var channel: MethodChannel + + companion object { + private const val CHANNEL_NAME = "google_mlkit_text_recognizer" + } + + override fun onAttachedToEngine(binding: FlutterPlugin.FlutterPluginBinding) { + channel = MethodChannel(binding.binaryMessenger, CHANNEL_NAME) + channel.setMethodCallHandler(TextRecognizer(binding.applicationContext)) + } + + override fun onDetachedFromEngine(binding: FlutterPlugin.FlutterPluginBinding) { + channel.setMethodCallHandler(null) + } +} diff --git a/packages/google_mlkit_text_recognition/android/src/main/kotlin/com/google_mlkit_text_recognition/TextRecognizer.kt b/packages/google_mlkit_text_recognition/android/src/main/kotlin/com/google_mlkit_text_recognition/TextRecognizer.kt new file mode 100644 index 00000000..6de10613 --- /dev/null +++ b/packages/google_mlkit_text_recognition/android/src/main/kotlin/com/google_mlkit_text_recognition/TextRecognizer.kt @@ -0,0 +1,168 @@ +package com.google_mlkit_text_recognition + +import android.content.Context +import android.graphics.Point +import android.graphics.Rect +import com.google.mlkit.vision.text.TextRecognition +import com.google.mlkit.vision.text.chinese.ChineseTextRecognizerOptions +import com.google.mlkit.vision.text.devanagari.DevanagariTextRecognizerOptions +import com.google.mlkit.vision.text.japanese.JapaneseTextRecognizerOptions +import com.google.mlkit.vision.text.korean.KoreanTextRecognizerOptions +import com.google.mlkit.vision.text.latin.TextRecognizerOptions +import com.google_mlkit_commons.InputImageConverter +import io.flutter.plugin.common.MethodCall +import io.flutter.plugin.common.MethodChannel + +class TextRecognizer( + private val context: Context, +) : MethodChannel.MethodCallHandler { + private val instances = mutableMapOf() + + companion object { + private const val START = "vision#startTextRecognizer" + private const val CLOSE = "vision#closeTextRecognizer" + } + + override fun onMethodCall( + call: MethodCall, + result: MethodChannel.Result, + ) { + when (call.method) { + START -> { + handleDetection(call, result) + } + + CLOSE -> { + closeDetector(call) + result.success(null) + } + + else -> { + result.notImplemented() + } + } + } + + private fun initialize(call: MethodCall): com.google.mlkit.vision.text.TextRecognizer? = + when (call.argument("script")) { + 0 -> TextRecognition.getClient(TextRecognizerOptions.DEFAULT_OPTIONS) + 1 -> TextRecognition.getClient(ChineseTextRecognizerOptions.Builder().build()) + 2 -> TextRecognition.getClient(DevanagariTextRecognizerOptions.Builder().build()) + 3 -> TextRecognition.getClient(JapaneseTextRecognizerOptions.Builder().build()) + 4 -> TextRecognition.getClient(KoreanTextRecognizerOptions.Builder().build()) + else -> null + } + + private fun handleDetection( + call: MethodCall, + result: MethodChannel.Result, + ) { + val imageData = call.argument>("imageData") ?: return + val inputImage = InputImageConverter.getInputImageFromData(imageData, context, result) ?: return + + val id = call.argument("id") ?: return + val textRecognizer = + instances.getOrPut(id) { + initialize(call) ?: run { + result.error("TextRecognizerError", "TextRecognizer is not initialized", null) + return + } + } + + textRecognizer + .process(inputImage) + .addOnSuccessListener { text -> + val textResult = + mutableMapOf( + "text" to text.text, + "blocks" to + text.textBlocks.map { block -> + mutableMapOf().apply { + addData(this, block.text, block.boundingBox, block.cornerPoints, block.recognizedLanguage, null, null) + put( + "lines", + block.lines.map { line -> + mutableMapOf().apply { + addData( + this, + line.text, + line.boundingBox, + line.cornerPoints, + line.recognizedLanguage, + line.confidence, + line.angle, + ) + put( + "elements", + line.elements.map { element -> + mutableMapOf().apply { + addData( + this, + element.text, + element.boundingBox, + element.cornerPoints, + element.recognizedLanguage, + element.confidence, + element.angle, + ) + put( + "symbols", + element.symbols.map { symbol -> + mutableMapOf().apply { + addData( + this, + symbol.text, + symbol.boundingBox, + symbol.cornerPoints, + symbol.recognizedLanguage, + symbol.confidence, + symbol.angle, + ) + } + }, + ) + } + }, + ) + } + }, + ) + } + }, + ) + result.success(textResult) + }.addOnFailureListener { e -> + result.error("TextRecognizerError", e.toString(), null) + } + } + + private fun addData( + addTo: MutableMap, + text: String, + rect: Rect?, + cornerPoints: Array?, + recognizedLanguage: String, + confidence: Float?, + angle: Float?, + ) { + addTo["text"] = text + addTo["rect"] = rect?.let { getBoundingPoints(it) } + addTo["points"] = cornerPoints?.map { mapOf("x" to it.x, "y" to it.y) } ?: emptyList>() + addTo["recognizedLanguages"] = listOf(recognizedLanguage) + addTo["confidence"] = confidence + addTo["angle"] = angle + } + + private fun getBoundingPoints(rect: Rect) = + mapOf( + "left" to rect.left, + "right" to rect.right, + "top" to rect.top, + "bottom" to rect.bottom, + ) + + private fun closeDetector(call: MethodCall) { + val id = call.argument("id") ?: return + instances.remove(id)?.close() + } +} diff --git a/packages/google_mlkit_translation/android/build.gradle b/packages/google_mlkit_translation/android/build.gradle index ce0c92a7..97b77233 100644 --- a/packages/google_mlkit_translation/android/build.gradle +++ b/packages/google_mlkit_translation/android/build.gradle @@ -2,6 +2,7 @@ group = "com.google_mlkit_translation" version = "1.0" buildscript { + ext.kotlin_version = "2.2.20" repositories { google() mavenCentral() @@ -9,6 +10,7 @@ buildscript { dependencies { classpath("com.android.tools.build:gradle:8.13.0") + classpath('org.jetbrains.kotlin:kotlin-gradle-plugin:2.2.20') } } @@ -20,6 +22,7 @@ rootProject.allprojects { } apply plugin: "com.android.library" +apply plugin: "kotlin-android" android { namespace = "com.google_mlkit_translation" @@ -31,6 +34,15 @@ android { targetCompatibility = JavaVersion.VERSION_11 } + kotlinOptions { + jvmTarget = '11' + } + + sourceSets { + main.java.srcDirs += "src/main/kotlin" + } + + defaultConfig { minSdk = 21 } diff --git a/packages/google_mlkit_translation/android/gradle/wrapper/gradle-wrapper.properties b/packages/google_mlkit_translation/android/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 00000000..128196a7 --- /dev/null +++ b/packages/google_mlkit_translation/android/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-9.0-milestone-1-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/packages/google_mlkit_translation/android/src/main/java/com/google_mlkit_translation/GoogleMlKitTranslationPlugin.java b/packages/google_mlkit_translation/android/src/main/java/com/google_mlkit_translation/GoogleMlKitTranslationPlugin.java deleted file mode 100644 index 32ea27bb..00000000 --- a/packages/google_mlkit_translation/android/src/main/java/com/google_mlkit_translation/GoogleMlKitTranslationPlugin.java +++ /dev/null @@ -1,22 +0,0 @@ -package com.google_mlkit_translation; - -import androidx.annotation.NonNull; - -import io.flutter.embedding.engine.plugins.FlutterPlugin; -import io.flutter.plugin.common.MethodChannel; - -public class GoogleMlKitTranslationPlugin implements FlutterPlugin { - private MethodChannel channel; - private static final String channelName = "google_mlkit_on_device_translator"; - - @Override - public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) { - channel = new MethodChannel(flutterPluginBinding.getBinaryMessenger(), channelName); - channel.setMethodCallHandler(new TextTranslator()); - } - - @Override - public void onDetachedFromEngine(@NonNull FlutterPluginBinding binding) { - channel.setMethodCallHandler(null); - } -} diff --git a/packages/google_mlkit_translation/android/src/main/java/com/google_mlkit_translation/TextTranslator.java b/packages/google_mlkit_translation/android/src/main/java/com/google_mlkit_translation/TextTranslator.java deleted file mode 100644 index 42c5561a..00000000 --- a/packages/google_mlkit_translation/android/src/main/java/com/google_mlkit_translation/TextTranslator.java +++ /dev/null @@ -1,90 +0,0 @@ -package com.google_mlkit_translation; - -import androidx.annotation.NonNull; - -import com.google.mlkit.nl.translate.TranslateRemoteModel; -import com.google.mlkit.nl.translate.Translation; -import com.google.mlkit.nl.translate.Translator; -import com.google.mlkit.nl.translate.TranslatorOptions; -import com.google_mlkit_commons.GenericModelManager; - -import java.util.HashMap; -import java.util.Map; - -import io.flutter.plugin.common.MethodCall; -import io.flutter.plugin.common.MethodChannel; - -public class TextTranslator implements MethodChannel.MethodCallHandler { - private static final String START = "nlp#startLanguageTranslator"; - private static final String CLOSE = "nlp#closeLanguageTranslator"; - private static final String MANAGE = "nlp#manageLanguageModelModels"; - - private final Map instances = new HashMap<>(); - private final GenericModelManager genericModelManager = new GenericModelManager(); - - @Override - public void onMethodCall(@NonNull MethodCall call, @NonNull MethodChannel.Result result) { - String method = call.method; - switch (method) { - case START: - translateText(call, result); - break; - case CLOSE: - closeDetector(call); - result.success(null); - break; - case MANAGE: - manageModel(call, result); - break; - default: - result.notImplemented(); - break; - } - } - - private void translateText(MethodCall call, final MethodChannel.Result result) { - String text = call.argument("text"); - - String id = call.argument("id"); - Translator onDeviceTranslator = instances.get(id); - if (onDeviceTranslator == null) { - String sourceLanguage = call.argument("source"); - String targetLanguage = call.argument("target"); - TranslatorOptions options = new TranslatorOptions.Builder() - .setSourceLanguage(sourceLanguage) - .setTargetLanguage(targetLanguage) - .build(); - onDeviceTranslator = Translation.getClient(options); - instances.put(id, onDeviceTranslator); - } - final Translator translator = onDeviceTranslator; - - translator.downloadModelIfNeeded() - .addOnSuccessListener( - (OnSuccessListener) -> { - // Model downloaded successfully. Okay to start translating. - translator.translate(text) - .addOnSuccessListener(result::success) - .addOnFailureListener( - e -> result.error("error translating", e.toString(), null)); - }) - .addOnFailureListener( - e -> { - // Model could not be downloaded or other internal error. - result.error("Error building translator", "Either source or target models not downloaded", null); - }); - } - - private void closeDetector(MethodCall call) { - String id = call.argument("id"); - Translator translator = instances.get(id); - if (translator == null) return; - translator.close(); - instances.remove(id); - } - - private void manageModel(MethodCall call, final MethodChannel.Result result) { - TranslateRemoteModel model = new TranslateRemoteModel.Builder(call.argument("model")).build(); - genericModelManager.manageModel(model, call, result); - } -} diff --git a/packages/google_mlkit_translation/android/src/main/kotlin/com/google_mlkit_translation/GoogleMlKitTranslationPlugin.kt b/packages/google_mlkit_translation/android/src/main/kotlin/com/google_mlkit_translation/GoogleMlKitTranslationPlugin.kt new file mode 100644 index 00000000..a68d7d27 --- /dev/null +++ b/packages/google_mlkit_translation/android/src/main/kotlin/com/google_mlkit_translation/GoogleMlKitTranslationPlugin.kt @@ -0,0 +1,21 @@ +package com.google_mlkit_translation + +import io.flutter.embedding.engine.plugins.FlutterPlugin +import io.flutter.plugin.common.MethodChannel + +class GoogleMlKitTranslationPlugin : FlutterPlugin { + private lateinit var channel: MethodChannel + + companion object { + private const val CHANNEL_NAME = "google_mlkit_on_device_translator" + } + + override fun onAttachedToEngine(binding: FlutterPlugin.FlutterPluginBinding) { + channel = MethodChannel(binding.binaryMessenger, CHANNEL_NAME) + channel.setMethodCallHandler(TextTranslator()) + } + + override fun onDetachedFromEngine(binding: FlutterPlugin.FlutterPluginBinding) { + channel.setMethodCallHandler(null) + } +} diff --git a/packages/google_mlkit_translation/android/src/main/kotlin/com/google_mlkit_translation/TextTranslator.kt b/packages/google_mlkit_translation/android/src/main/kotlin/com/google_mlkit_translation/TextTranslator.kt new file mode 100644 index 00000000..a1e82b01 --- /dev/null +++ b/packages/google_mlkit_translation/android/src/main/kotlin/com/google_mlkit_translation/TextTranslator.kt @@ -0,0 +1,89 @@ +package com.google_mlkit_translation + +import com.google.mlkit.nl.translate.TranslateRemoteModel +import com.google.mlkit.nl.translate.Translation +import com.google.mlkit.nl.translate.Translator +import com.google.mlkit.nl.translate.TranslatorOptions +import com.google_mlkit_commons.GenericModelManager +import io.flutter.plugin.common.MethodCall +import io.flutter.plugin.common.MethodChannel + +class TextTranslator : MethodChannel.MethodCallHandler { + private val instances = mutableMapOf() + private val genericModelManager = GenericModelManager() + + companion object { + private const val START = "nlp#startLanguageTranslator" + private const val CLOSE = "nlp#closeLanguageTranslator" + private const val MANAGE = "nlp#manageLanguageModelModels" + } + + override fun onMethodCall( + call: MethodCall, + result: MethodChannel.Result, + ) { + when (call.method) { + START -> { + translateText(call, result) + } + + CLOSE -> { + closeDetector(call) + result.success(null) + } + + MANAGE -> { + manageModel(call, result) + } + + else -> { + result.notImplemented() + } + } + } + + private fun translateText( + call: MethodCall, + result: MethodChannel.Result, + ) { + val text = call.argument("text") ?: return + val id = call.argument("id") ?: return + + val translator = + instances.getOrPut(id) { + val sourceLanguage = call.argument("source") ?: return + val targetLanguage = call.argument("target") ?: return + Translation.getClient( + TranslatorOptions + .Builder() + .setSourceLanguage(sourceLanguage) + .setTargetLanguage(targetLanguage) + .build(), + ) + } + + translator + .downloadModelIfNeeded() + .addOnSuccessListener { + translator + .translate(text) + .addOnSuccessListener { result.success(it) } + .addOnFailureListener { e -> result.error("error translating", e.toString(), null) } + }.addOnFailureListener { + result.error("Error building translator", "Either source or target models not downloaded", null) + } + } + + private fun closeDetector(call: MethodCall) { + val id = call.argument("id") ?: return + instances.remove(id)?.close() + } + + private fun manageModel( + call: MethodCall, + result: MethodChannel.Result, + ) { + val model = TranslateRemoteModel.Builder(call.argument("model").toString()).build() + genericModelManager.manageModel(model, call, result) + } +}