import gradio as gr
# JavaScript handler executed in the browser whenever the hidden points JSON
# changes (it is attached through the `_js` argument of `points.change`).
# It receives the current image and the list of points and resolves to a
# one-element array: the mask URL, or null to clear the mask output.
# NOTE(review): `segmentPoints` is not defined in this file; presumably it is
# a global supplied by the hosting page -- confirm.
get_point_mask = """
async function getPointMask(image, points) {
  console.log("getting point mask");
  // Nothing to segment: skip the model call entirely and clear the mask.
  if (!image || !points || points.length === 0) {
    return [ null ];
  }
  const { maskURL } = await segmentPoints(
    image,
    points
  );
  return [ maskURL ];
}
"""
def set_points(image, points_state, evt: gr.SelectData):
    """Record the clicked pixel as a point relative to the image size.

    The click position in ``evt.index`` is divided by the image's width and
    height, and the resulting ``[x, y, True]`` entry is appended to
    ``points_state`` in place.

    Args:
        image: Object exposing ``width`` and ``height`` (the input image).
        points_state: List of points collected so far; mutated in place.
        evt: Selection event whose ``index`` holds the clicked ``(x, y)``.

    Returns:
        A pair containing the same updated list twice, one per output.
    """
    rel_x = evt.index[0] / image.width
    rel_y = evt.index[1] / image.height
    # NOTE(review): the trailing True presumably flags the point as a
    # positive/foreground prompt -- confirm against the JS consumer.
    new_point = [rel_x, rel_y, True]
    points_state.append(new_point)
    return points_state, points_state
# Build the page: an intro text, an input column (image, clear button and a
# hidden JSON holding the clicked points) and an output column with the mask.
with gr.Blocks() as demo:
    gr.Markdown("""## Segment Anything Model (SAM) with Gradio Lite
This demo uses [Gradio Lite](https://www.gradio.app/guides/gradio-lite) as UI for running the Segment Anything Model (SAM) with WASM build with [Candle](https://github.com/huggingface/candle).
**Note:** The model's first run may take a few seconds as it loads and caches the model in the browser, and then creates the image embeddings. Any subsequent clicks on points will be significantly faster.
""")
    # Per-session list of the points clicked so far.
    points_state = gr.State([])
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Input Image", type="pil")
            clear_points = gr.Button(value="Clear Points")
            # Hidden component: its change event is what triggers the JS handler.
            points = gr.JSON(label="Input Points", visible=False)
        with gr.Column():
            mask = gr.Image(label="Output Mask")

    # Reset both the hidden JSON and the state to empty lists.
    clear_points.click(
        fn=lambda: ([], []),
        inputs=None,
        outputs=[points, points_state],
    )
    # Each click on the image appends one point.
    image.select(
        fn=set_points,
        inputs=[image, points_state],
        outputs=[points, points_state],
    )
    # No Python callback here: the work is done by the JavaScript snippet.
    points.change(
        fn=None,
        inputs=[image, points],
        outputs=[mask],
        _js=get_point_mask,
    )

demo.launch(show_api=False)