Skip to content

Commit f7519ee

Browse files
authored
Add Segment Any Text (SaT) model (#7)
* segment any text model * Update README to mention PyTorch version * Improve formatting in README.md for commands Updated README formatting for clarity. * add hugging face link to download model * changed directory location
1 parent db44950 commit f7519ee

6 files changed

Lines changed: 1665 additions & 0 deletions

File tree

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Segment Any Text CoreML
2+
Segment Any Text is state-of-the-art sentence segmentation with 3 Transformer layers. A PyTorch version of the model is used in [wtpsplit](https://github.com/segment-any-text/wtpsplit) and additional details can be found in this [paper](https://arxiv.org/abs/2406.16678).
3+
4+
If you wish to skip the CoreML conversion, you can download a precompiled `SaT.mlmodelc` from [Hugging Face](https://huggingface.co/smdesai/SaT).
5+
6+
7+
# CoreML Conversion
8+
9+
## Environment Setup
10+
11+
1. Install [uv](https://github.com/astral-sh/uv) if it is not already available.
12+
2. Sync the project environment.
13+
```bash
14+
uv sync
15+
```
16+
3. Activate the virtual environment:
17+
```bash
18+
source .venv/bin/activate
19+
```
20+
21+
## Converting the Model
22+
23+
Run the conversion script to create the SaT Core ML package:
24+
25+
```bash
26+
python convert-sat.py --model-id segment-any-text/sat-3l-sm --output-dir sat_coreml
27+
```
28+
29+
This produces `SaT.mlpackage` in the `sat_coreml` directory.
30+
31+
Here is the complete usage:
32+
```bash
33+
Usage: convert-sat.py [OPTIONS]
34+
35+
Options
36+
--model-id TEXT Model identifier to download
37+
from HuggingFace model hub
38+
[default:
39+
segment-any-text/sat-3l-sm]
40+
--output-dir PATH Directory to write mlpackage and metadata
41+
[default: sat_coreml]
42+
--conversion-type -c TEXT Conversion methods to apply to
43+
the model. Repeat the option to
44+
chain conversions (allowed:
45+
none, prune, quantize,
46+
palettize; default: none).
47+
[default: None]
48+
```
49+
50+
## Compiling the Model
51+
52+
Run the following to compile the model.
53+
```bash
54+
python compile_mlmodelc.py --coreml-dir sat_coreml
55+
```
56+
57+
This produces `SaT.mlmodelc` in the `compiled` directory.
58+
59+
Here is the complete usage:
60+
```bash
61+
Usage: compile_mlmodelc.py [OPTIONS]
62+
63+
Options
64+
--coreml-dir PATH Directory where mlpackages and metadata are written
65+
[default: sat_coreml]
66+
```
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from __future__ import annotations
2+
3+
import shutil
4+
import subprocess
5+
import sys
6+
import typer
7+
from pathlib import Path
8+
9+
BASE_DIR = Path(__file__).resolve().parent
10+
OUTPUT_ROOT = BASE_DIR / "compiled"
11+
12+
app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False)
13+
14+
15+
def ensure_coremlcompiler() -> None:
    """Verify that ``xcrun`` exists and can locate ``coremlcompiler``; exit(1) if not."""
    xcrun = shutil.which("xcrun")
    if xcrun is None:
        print("Error: 'xcrun' not found on PATH. Install Xcode command line tools.", file=sys.stderr)
        sys.exit(1)

    # Probe for the compiler; output is captured so nothing is printed on success.
    probe = [xcrun, "--find", "coremlcompiler"]
    try:
        subprocess.run(probe, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except subprocess.CalledProcessError:
        print("Error: 'coremlcompiler' not found via xcrun. Check your Xcode installation.", file=sys.stderr)
        sys.exit(1)
31+
32+
33+
def gather_packages(dir: str | Path) -> list[Path]:
    """Return every ``*.mlpackage`` bundle found under ``BASE_DIR / dir``.

    Args:
        dir: Directory to search recursively, resolved relative to this
            script. The caller passes a ``Path``, so the annotation accepts
            both ``str`` and ``Path`` (the original hint said ``str`` only).
            NOTE: the name shadows the ``dir`` builtin; kept so existing
            keyword callers are not broken.

    Returns:
        List of discovered package paths; empty (with a warning on stderr)
        when the directory does not exist.
    """
    packages: list[Path] = []
    source = BASE_DIR / dir
    if not source.exists():
        print(f"Warning: {source.relative_to(BASE_DIR)} does not exist; skipping", file=sys.stderr)
        return packages
    packages.extend(source.rglob("*.mlpackage"))
    return packages
42+
43+
44+
def compile_package(package: Path) -> None:
    """Compile one ``.mlpackage`` into ``OUTPUT_ROOT`` via ``xcrun coremlcompiler``.

    Any pre-existing ``<name>.mlmodelc`` in the output directory is removed
    first so the compiler always writes a fresh bundle.

    Args:
        package: Path to the ``.mlpackage`` bundle (must live under BASE_DIR).

    Raises:
        subprocess.CalledProcessError: if the compiler exits non-zero.
    """
    relative_pkg = package.relative_to(BASE_DIR)
    # All compiled bundles land flat in OUTPUT_ROOT; the source tree layout
    # is intentionally not mirrored. (Removed a commented-out alternative.)
    output_dir = OUTPUT_ROOT
    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / f"{package.stem}.mlmodelc"

    if output_path.exists():
        shutil.rmtree(output_path)

    cmd = [
        "xcrun",
        "coremlcompiler",
        "compile",
        str(package),
        str(output_dir),
    ]

    print(f"Compiling {relative_pkg} -> {output_path.relative_to(BASE_DIR)}")
    subprocess.run(cmd, check=True)
65+
66+
67+
@app.command()
def compile(
    coreml_dir: Path = typer.Option(
        Path("sat_coreml"),
        help="Directory where mlpackages and metadata are written",
    ),
):
    # CLI entry point: compile every mlpackage found under *coreml_dir*.
    # (Comment rather than docstring so typer's --help output is unchanged.)
    ensure_coremlcompiler()
    packages = gather_packages(coreml_dir)

    if not packages:
        print("No .mlpackage bundles found to compile.")
        return

    for pkg in packages:
        try:
            compile_package(pkg)
        except subprocess.CalledProcessError as err:
            # Stop at the first failure and propagate the compiler's exit code.
            print(f"Failed to compile {pkg}: {err}", file=sys.stderr)
            sys.exit(err.returncode)

    print(f"Finished compiling {len(packages)} package(s) into {OUTPUT_ROOT.relative_to(BASE_DIR)}.")


if __name__ == "__main__":
    app()
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import json
import os
from enum import IntEnum
from pathlib import Path

import coremltools.optimize.coreml as cto_coreml
6+
7+
8+
class Conversion(IntEnum):
    """Weight-optimization strategies that can be applied to a Core ML package."""

    NONE = 0       # leave weights untouched
    PRUNE = 1      # zero small-magnitude weights
    QUANTIZE = 2   # linear weight quantization
    PALETTIZE = 3  # cluster weights into a palette (LUT)
13+
14+
15+
def update_manifest_model_name(manifest_path: str, new_name: str) -> None:
16+
with open(manifest_path, "r") as file:
17+
manifest = json.load(file)
18+
19+
for key, value in manifest["itemInfoEntries"].items():
20+
if value["name"] == "model.mlmodel":
21+
value["name"] = f"{new_name}.mlmodel"
22+
value["path"] = f"com.apple.CoreML/{new_name}.mlmodel"
23+
24+
with open(manifest_path, "w") as file:
25+
json.dump(manifest, file, indent=4)
26+
27+
print(f"Manifest updated. Model name changed to {new_name}.mlmodel")
28+
29+
old_model_path = os.path.join(
30+
os.path.dirname(manifest_path), "Data/com.apple.CoreML/model.mlmodel"
31+
)
32+
new_model_path = os.path.join(
33+
os.path.dirname(manifest_path),
34+
f"Data/com.apple.CoreML/{new_name}.mlmodel",
35+
)
36+
if os.path.exists(old_model_path):
37+
os.rename(old_model_path, new_model_path)
38+
print(f"Model file renamed from model.mlmodel to {new_name}.mlmodel")
39+
else:
40+
print("Warning: model.mlmodel not found. Only manifest was updated.")
41+
42+
43+
def palettize_model(mlpackage, *, bits: int = 8, weight_threshold: int = 512):
    """Cluster model weights into a ``bits``-bit palette.

    Returns the palettized model, or ``None`` if palettization fails.
    """
    print(f"\nApplying {bits}-bit palettization...")
    try:
        palettizer = cto_coreml.OpPalettizerConfig(
            nbits=bits,
            weight_threshold=weight_threshold,
        )
        return cto_coreml.palettize_weights(
            mlpackage, cto_coreml.OptimizationConfig(palettizer)
        )
    except Exception as e:
        print(f"Error palettization failed: {e}")
        return None
55+
56+
57+
def prune_model(mlpackage, *, threshold: float = 0.01):
    """Zero out weights whose magnitude falls below *threshold*.

    Args:
        mlpackage: Core ML model (mlprogram) to prune.
        threshold: Magnitude below which weights are pruned.

    Returns:
        The pruned model, or ``None`` if pruning fails.
    """
    # Plain string: the original used an f-string with no placeholders (F541).
    print("\nApplying pruning quantization...")
    try:
        config = cto_coreml.OptimizationConfig(
            global_config=cto_coreml.OpThresholdPrunerConfig(threshold=threshold)
        )
        return cto_coreml.prune_weights(mlpackage, config)
    except Exception as e:
        print(f"Error pruning failed: {e}")
        return None
67+
68+
69+
def quantize_model(mlpackage, *, dtype: str = "int8", mode: str = "linear_symmetric"):
    """Linearly quantize model weights (per-block, block size 32).

    Args:
        mlpackage: Core ML model (mlprogram) to quantize.
        dtype: Target weight dtype (e.g. ``"int8"``).
        mode: Quantization mode (e.g. ``"linear"`` or ``"linear_symmetric"``).

    Returns:
        The quantized model, or ``None`` if quantization fails.
    """
    # Bug fix: the original wrote ``if str == "linear":`` — comparing the
    # builtin ``str`` type to a string literal, which is always False. The
    # intent was clearly to branch on the ``mode`` argument.
    if mode == "linear":
        print(f"\nApplying {dtype} quantization...")
    else:
        print("\nApplying mixed precision quantization...")

    try:
        op_config = cto_coreml.OpLinearQuantizerConfig(
            mode=mode,
            dtype=dtype,
            granularity="per_block",
            block_size=32,
        )
        config = cto_coreml.OptimizationConfig(global_config=op_config)
        return cto_coreml.linear_quantize_weights(mlpackage, config)
    except Exception as e:
        # Fixed message: the original hardcoded "INT8" even for other dtypes.
        print(f"{dtype} quantization failed: {e}")
        return None
88+
89+
90+
def apply_conversion(mlpackage, conversion_type: Conversion):
    """Dispatch *conversion_type* to the matching optimization routine.

    Raises:
        ValueError: if *conversion_type* is not a known Conversion member.
    """
    handlers = {
        Conversion.NONE: lambda pkg: pkg,
        Conversion.PRUNE: prune_model,
        Conversion.QUANTIZE: quantize_model,
        Conversion.PALETTIZE: palettize_model,
    }
    try:
        handler = handlers[conversion_type]
    except KeyError:
        raise ValueError(f"Unsupported conversion type: {conversion_type}") from None
    return handler(mlpackage)
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
from __future__ import annotations
2+
3+
import os
4+
from pathlib import Path
5+
6+
import coremltools as ct
7+
import numpy as np
8+
import torch
9+
import typer
10+
11+
from transformers import AutoModelForTokenClassification, AutoTokenizer
12+
import wtpsplit.models # registers SubwordXLM config/model types
13+
14+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
15+
16+
DEFAULT_MODEL_ID = "segment-any-text/sat-3l-sm"
17+
18+
19+
from conversion_utils import (
20+
Conversion,
21+
apply_conversion,
22+
update_manifest_model_name,
23+
)
24+
25+
app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False)
26+
27+
28+
def parse_conversion_type(value: str | None) -> Conversion:
    """Map one CLI string to a ``Conversion`` member.

    ``None`` or a blank/whitespace-only string means ``Conversion.NONE``.

    Raises:
        typer.BadParameter: if the string is not a valid conversion name.
    """
    cleaned = (value or "").strip()
    if not cleaned:
        return Conversion.NONE

    try:
        return Conversion[cleaned.upper()]
    except KeyError as exc:
        raise typer.BadParameter(
            f"Invalid conversion type '{value}'. "
            "Choose from 'none', 'prune', 'quantize', or 'palettize'."
        ) from exc
42+
43+
44+
def parse_conversion_types(
    values: tuple[str, ...] | list[str] | None,
) -> list[Conversion]:
    """Parse every CLI conversion string into ``Conversion`` members.

    An empty or ``None`` sequence defaults to ``[Conversion.NONE]``.

    Raises:
        typer.BadParameter: if any entry is not a valid conversion name.
    """
    if not values:
        return [Conversion.NONE]

    # Idiom: comprehension instead of the original manual append loop.
    return [parse_conversion_type(item) for item in values]
54+
55+
56+
@app.command()
def convert(
    model_id: str = typer.Option(
        DEFAULT_MODEL_ID,
        "--model-id",
        help="Model identifier to download from HuggingFace's model hub",
    ),
    output_dir: Path = typer.Option(
        Path("sat_coreml"),
        help="Directory where mlpackages and metadata will be written",
    ),
    conversion_types: list[str] = typer.Option(
        None,
        "--conversion-type",
        "-c",
        help=(
            "Conversion methods to apply to the model. "
            "Repeat the option to chain conversions "
            "(allowed: none, prune, quantize, palettize; default: none)."
        ),
    ),
):
    """Trace the SaT model, convert it to a Core ML mlprogram, and save it.

    Downloads the model and tokenizer, traces with a fixed 512-token input,
    converts via coremltools, applies any requested weight optimizations in
    order, then writes ``SaT.mlpackage`` (with a renamed manifest entry) to
    *output_dir*.
    """
    conversions_to_apply = parse_conversion_types(conversion_types)

    # torchscript=True / return_dict=False make the forward pass traceable.
    model = AutoModelForTokenClassification.from_pretrained(
        model_id,
        return_dict=False,
        torchscript=True,
        trust_remote_code=True,
    ).eval()

    tokenizer = AutoTokenizer.from_pretrained("facebookAI/xlm-roberta-base")
    tokenized = tokenizer(
        ["Sample input text to trace the model."],
        return_tensors="pt",
        max_length=512,  # token sequence length
        padding="max_length",
    )

    traced_model = torch.jit.trace(
        model,
        (tokenized["input_ids"], tokenized["attention_mask"]),
    )

    outputs = [ct.TensorType(name="output")]

    mlpackage = ct.convert(
        traced_model,
        convert_to="mlprogram",
        inputs=[
            # Fix: pass the tokenizer key directly as the ``name`` keyword;
            # the original wrapped it in a redundant f-string positional arg.
            ct.TensorType(
                name=name,
                shape=tensor.shape,
                dtype=np.int32,
            )
            for name, tensor in tokenized.items()
        ],
        outputs=outputs,
        compute_units=ct.ComputeUnit.ALL,
        minimum_deployment_target=ct.target.iOS18,
    )

    try:
        new_model = mlpackage
        for conversion in conversions_to_apply:
            new_model = apply_conversion(new_model, conversion)
            if new_model is None:
                # The optimization helpers return None on failure; bail out
                # with a message instead of crashing on ``None.save`` below.
                print(f"Conversion {conversion.name.lower()} failed; aborting.")
                return
    except ValueError as e:
        print(e)
        return

    saved_name = "SaT"
    output_dir.mkdir(parents=True, exist_ok=True)  # ensure destination exists
    saved_path = output_dir / f"{saved_name}.mlpackage"
    new_model.save(saved_path)

    manifest_file = saved_path / "Manifest.json"
    update_manifest_model_name(manifest_file, saved_name)


if __name__ == "__main__":
    app()

0 commit comments

Comments
 (0)