import sys
import subprocess

# Streamlit re-executes this whole script on every widget interaction, so an
# unconditional `pip install` here would spawn a pip process on every rerun.
# Only install streamlit when it is actually missing.
try:
    import streamlit as st
except ImportError:
    import importlib

    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'streamlit'])
    # The package was installed after the interpreter started; drop the
    # import system's cached directory listings so the new package is found.
    importlib.invalidate_caches()
    import streamlit as st
import torch
from cmn import plots as z, model_analysis
import numpy as np
import cmn
from cmn.classification import LitClassifier
from cmn.training import LitYolo
from cmn.runs import Run
from collections import Counter
# Detector run name: looked up in the mlruns folder, used as the stem of the
# pickled model file, and displayed in the sidebar.
MODEL_NAME: str = 'serious-hound-714'
# Run name of the crop classifier in the mlruns folder.
CLASSIFIER_NAME: str = 'stylish-ram-967'
# Initial value of the sidebar "Confidence" slider.
VAL_CONFIDENCE: float = 0.35
# Border forwarded to predict_boxes when cropping each detected box.
BORDER: float = 0.0
# Torch device that both models and their inputs are placed on.
DEVICE: str = 'cpu'
@st.cache_resource
def load_model():
    """Load the box-detection network for run ``MODEL_NAME``.

    The architecture is unpickled from ``f'{MODEL_NAME}.pkl'`` and its weights
    are restored from the run's checkpoint via ``LitYolo.load_from_checkpoint``.
    The result is cached with ``st.cache_resource`` so Streamlit reruns reuse it.

    Returns:
        The inner ``.model`` of the loaded ``LitYolo``, on ``DEVICE`` and in
        eval mode.
    """
    run = Run(folder='/home/mechi/cumana/mlruns', model_name=MODEL_NAME)
    # The pickle may have been written from a GPU session; map its tensors onto
    # DEVICE so loading also works on CPU-only hosts.
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted
    # files. On torch >= 2.6 (weights_only=True by default) loading a whole
    # pickled module may additionally need weights_only=False; confirm.
    base_model = torch.load(f'{MODEL_NAME}.pkl', map_location=DEVICE).eval()
    lit_model = LitYolo.load_from_checkpoint(run.checkpoint, model=base_model)
    return lit_model.to(DEVICE).eval().model
@st.cache_resource
def load_classifier():
    """Load the crop classifier for run ``CLASSIFIER_NAME``.

    Cached with ``st.cache_resource`` so Streamlit reruns reuse one instance.

    Returns:
        The ``LitClassifier`` restored from the run's checkpoint, moved to
        ``DEVICE`` and switched to eval mode.
    """
    classifier_run = Run(folder='/home/mechi/cumana/mlruns', model_name=CLASSIFIER_NAME)
    lit_classifier = LitClassifier.load_from_checkpoint(classifier_run.checkpoint)
    lit_classifier = lit_classifier.to(DEVICE)
    return lit_classifier.eval()


# Module-level handle that predict_boxes and the UI read; thanks to the cache,
# calling it on every rerun does not reload the checkpoint.
clasificador = load_classifier()
def box_predictor(model, confidence, device='cpu', nms_th=0.6, min_size=22 * 32):
    """Build a closure that runs box detection on one image.

    Args:
        model: detection network; moved to ``device`` once, here.
        confidence: threshold passed to ``get_boxes`` (which also receives
            ``torch.sigmoid`` to apply to the raw outputs).
        device: torch device for the model and the input tensor.
        nms_th: threshold passed to ``non_maximum_suppression``.
        min_size: argument for ``z.smallest_size_to`` applied to each image
            before inference (presumably the target length of the shorter
            side; the default 22 * 32 looks like a multiple of the network
            stride — confirm against cmn). Defaults to the previously
            hard-coded value.

    Returns:
        ``predictor(img)``. It returns element ``[0]`` of the NMS result
        (presumably the boxes for the single image in the batch).
    """
    model = model.to(device)

    @torch.no_grad()
    def predictor(img):
        resized = np.array(img | z.smallest_size_to(min_size))
        # Wrap in a list before stacking to add a leading batch axis of size 1.
        xs = torch.as_tensor(np.stack([resized])).to(device)
        y_hat = model(xs).cpu()
        all_boxes = cmn.model_analysis.get_boxes(y_hat, confidence, torch.sigmoid)
        return cmn.model_analysis.non_maximum_suppression(all_boxes, nms_th)[0]

    return predictor
@torch.no_grad()
def predict_boxes(crops, img, border=0.15, classifier=None):
    """Classify the image region under each detected box.

    Args:
        crops: detected boxes. Each must provide ``crop(img, border)``, whose
            result must provide ``resize((224, 224))`` yielding an H x W x C
            array-like (presumably a PIL image; confirm against cmn).
        img: the full image the boxes were detected on.
        border: margin forwarded to each box's ``crop``.
        classifier: object exposing ``.device`` and a callable ``.model``.
            ``.model`` maps an (N, C, 224, 224) float batch to per-class
            scores (softmax is applied here). Defaults to the module-level
            ``clasificador``.

    Returns:
        ``(probabilities, class_indices)``: two 1-D numpy arrays with one entry
        per crop, holding the top softmax probability and its class index.
        Both are empty when ``crops`` is empty.
    """
    patches = [box.crop(img, border).resize((224, 224)) for box in crops]
    if not patches:
        # np.stack raises on an empty list. "No detections" is a normal
        # outcome (e.g. a high confidence threshold), so return empty results.
        return np.empty(0, dtype=np.float32), np.empty(0, dtype=np.int64)
    classifier = clasificador if classifier is None else classifier
    # (N, H, W, C) -> (N, C, H, W) as expected by the classifier.
    batch = torch.as_tensor(np.stack(patches), device=classifier.device).permute(0, 3, 1, 2)
    # NOTE(review): scaling by 256 (not 255) presumably mirrors the training
    # preprocessing; keep it in sync with the classifier's training code.
    best = torch.softmax(classifier.model(batch / 256.0), dim=1).max(dim=1)
    return best.values.cpu().numpy(), best.indices.cpu().numpy()
##################################################################
# UI
##################################################################
st.sidebar.markdown(f"**Current Model:** {MODEL_NAME}")
confidence = st.sidebar.slider("Confidence", min_value=0.0, max_value=1.0, value=VAL_CONFIDENCE, step=0.01)
# NOTE(review): show_nmn ("show suppressed boxes") is never read below, so
# toggling this checkbox currently has no effect.
show_nmn = st.sidebar.checkbox('Mostrar Suprimidas')

image_files = st.file_uploader("Upload your file here...", accept_multiple_files=True)
camera = st.camera_input("Take a photo")

# Uploaded files take precedence; the camera snapshot is used only when no
# files have been uploaded.
sources = image_files or (camera and [camera]) or []

if sources:
    # The predictor depends only on the sidebar settings, not on the image, so
    # build it once per rerun instead of once per uploaded file.
    predictor = box_predictor(load_model(), confidence=confidence, device=DEVICE)

for image_file in sources:
    img = z.to_image(image_file)
    pred_boxes = predictor(img.copy())
    proba, classes = predict_boxes(pred_boxes, img, BORDER)
    # Attach the classifier's class index and probability to each box.
    pred_boxes = [b.assign(cat=c, confidence=prob) for b, c, prob in zip(pred_boxes, classes, proba)]
    st.image(
        img
        # | z.make_grid((32, 32), fill='#ffffff44')
        | z.plot_boxes(pred_boxes, width=6, color='red')
        # | z.scale(0.5)
        ,
        use_column_width=True
    )
    # Total number of detected bottles, then a per-class breakdown.
    st.markdown(f"**Cantidad de botellas:** {len(pred_boxes)}")
    found = Counter(clasificador.classes[b.cat] for b in pred_boxes)
    for cat, qty in sorted(found.items()):
        st.markdown(f"- **{cat.replace('__', ' ').replace('-', ' ')}:** {qty}")