Skip to content

Commit

Permalink
Add 2D embedding tab in the CC details view #18
Browse files Browse the repository at this point in the history
  • Loading branch information
blackwer committed Dec 6, 2023
1 parent 3326517 commit fedde8c
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 204 deletions.
260 changes: 78 additions & 182 deletions ManifoldEM/gui/eigen_views/cc_details_view.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import imageio
import itertools
import os
import pickle
import shutil
Expand All @@ -13,7 +12,7 @@

from PyQt5 import QtCore
from PyQt5.QtWidgets import (QMainWindow, QDialog, QTabWidget, QMessageBox, QPushButton, QSlider,
QLayout, QGridLayout, QProgressBar, QLabel, QComboBox, QCheckBox, QFrame)
QLayout, QGridLayout, QLabel, QComboBox, QCheckBox, QFrame)

from . import ClusterAvgMain

Expand All @@ -37,7 +36,7 @@ def _backup_restore(prd_index, backup=True):
# topos
srcdir = os.path.join(srcprefix, 'topos', f'PrD_{prd_index + 1}')
dstdir = os.path.join(dstprefix, 'topos', f'PrD_{prd_index + 1}')
shutil.copytree(srcdir, dstdir)
shutil.copytree(srcdir, dstdir, dirs_exist_ok=True)

# diff maps
srcfile = os.path.join(srcprefix, 'diff_maps', f'gC_trimmed_psi_prD_{prd_index}')
Expand Down Expand Up @@ -627,8 +626,6 @@ def slider_update(self): #update frame based on user slider position


class Manifold2dCanvas(QDialog):
progress1Changed = QtCore.Signal(int)
progress2Changed = QtCore.Signal(int)
data_changed = QtCore.pyqtSignal()

def __init__(self, prd_index: int, parent):
Expand All @@ -647,39 +644,13 @@ def __init__(self, prd_index: int, parent):
self.coordsX = [] #user X coordinate picks
self.coordsY = [] #user Y coordinate picks
self.connected = 0 #binary: 0=unconnected, 1=connected
self.pts_orig = []
self.pts_origX = []
self.pts_origY = []


self.figure = Figure(dpi=200)
self.ax = self.figure.add_subplot(111)
self.figure.set_tight_layout(True)
self.canvas = FigureCanvas(self.figure)

psi_file = p.get_psi_file(prd_index - 1) #current embedding
with open(psi_file, 'rb') as f:
data = pickle.load(f)
x = data['psi'][:, self.eigChoice1]
y = data['psi'][:, self.eigChoice2]

self.pts_orig = zip(x, y)
self.pts_origX = x
self.pts_origY = y
self.ax.scatter(self.pts_origX, self.pts_origY, s=1, c='#1f77b4') # plot initial data, C0

for tick in self.ax.xaxis.get_major_ticks():
tick.label1.set_fontsize(4)
for tick in self.ax.yaxis.get_major_ticks():
tick.label1.set_fontsize(4)
self.ax.get_xaxis().set_ticks([])
self.ax.get_yaxis().set_ticks([])
self.ax.set_title('Place points on the plot to encircle deviant cluster(s)', fontsize=3.5)
self.ax.set_xlabel(r'$\mathrm{\Psi}$%s' % (self.eigChoice1 + 1), fontsize=6)
self.ax.set_ylabel(r'$\mathrm{\Psi}$%s' % (self.eigChoice2 + 1), fontsize=6)
self.ax.autoscale()
self.canvas.mpl_connect('button_press_event', self.onclick)
self.canvas.draw() #refresh canvas


# canvas buttons:
self.btn_reset = QPushButton('Reset Plot')
Expand Down Expand Up @@ -711,10 +682,6 @@ def __init__(self, prd_index: int, parent):
self.btn_revert.setDefault(False)
self.btn_revert.setAutoDefault(False)

# disable reversion if manifold hasn't been reembedded
orig_embed = prd_index - 1 not in data_store.get_prds().reembed_ids
self.btn_revert.setDisabled(orig_embed)

self.btn_view = QPushButton('View Cluster')
self.btn_view.clicked.connect(self.view)
self.btn_view.setDisabled(True)
Expand All @@ -731,28 +698,34 @@ def __init__(self, prd_index: int, parent):
layout.addWidget(self.btn_rebed, 2, 4, 1, 1)
layout.addWidget(self.btn_revert, 2, 5, 1, 1)

self.progress1 = QProgressBar(minimum=0, maximum=100, value=0)
layout.addWidget(self.progress1, 3, 0, 1, 6)
self.progress1.show()

self.progress1Changed.connect(self.on_progress1Changed)
self.progress2Changed.connect(self.on_progress2Changed)

self.reload_psi_coords()
self.redraw()
self.setLayout(layout)


def reset(self):
self.pts_new = self.pts_orig
self.redraw()


def redraw(self):
self.connected = 0
self.btn_connect.setDisabled(True)
self.btn_remove.setDisabled(True)
self.btn_view.setDisabled(True)
self.btn_rebed.setDisabled(True)

# disable reversion if manifold hasn't been reembedded
orig_embed = self.prd_index - 1 not in data_store.get_prds().reembed_ids
self.btn_revert.setDisabled(orig_embed)

self.coordsX = []
self.coordsY = []

# redraw and resize figure:
self.ax.clear()
self.ax.scatter(self.pts_origX, self.pts_origY, s=1, c='#1f77b4')
x, y = zip(*self.pts_new)
self.ax.scatter(x, y, s=1, c='#1f77b4')

for tick in self.ax.xaxis.get_major_ticks():
tick.label1.set_fontsize(4)
Expand Down Expand Up @@ -793,40 +766,23 @@ def remove(self):
pts_newY = []

path = PlotPath(list(map(list, zip(self.coordsX, self.coordsY))), codes=None, closed=True, readonly=True)
inside = path.contains_points(np.dstack((self.pts_origX, self.pts_origY))[0].tolist(),
radius=1e-9)
x, y = zip(*self.pts_orig)
inside = path.contains_points(np.dstack((x, y))[0].tolist(), radius=1e-9)

for index, i in enumerate(inside):
if i == False:
pts_newX.append(self.pts_origX[index])
pts_newY.append(self.pts_origY[index])
self.pts_new = zip(pts_newX, pts_newY)
pts_newX.append(x[index])
pts_newY.append(y[index])


# crop out points, redraw and resize figure:
self.ax.clear()
self.ax.scatter(pts_newX, pts_newY, s=1, c='#1f77b4')
for tick in self.ax.xaxis.get_major_ticks():
tick.label1.set_fontsize(4)
for tick in self.ax.yaxis.get_major_ticks():
tick.label1.set_fontsize(4)
self.ax.get_xaxis().set_ticks([])
self.ax.get_yaxis().set_ticks([])
self.ax.set_title('Place points on the plot to encircle deviant cluster(s)', fontsize=3.5)
self.ax.set_xlabel(r'$\mathrm{\Psi}$%s' % (self.eigChoice1 + 1), fontsize=6)
self.ax.set_ylabel(r'$\mathrm{\Psi}$%s' % (self.eigChoice2 + 1), fontsize=6)
self.ax.autoscale()
self.canvas.draw()
self.btn_remove.setDisabled(True)
self.btn_view.setDisabled(True)
self.pts_new = list(zip(pts_newX, pts_newY))
self.redraw()
self.btn_rebed.setDisabled(False)


def rebed(self):
msg = 'Performing this action will recalculate the manifold \
embedding step for the current PD to include only the points shown.\
<br /><br />\
Do you want to proceed?'
msg = "Performing this action will recalculate the manifold embedding step for the current PD to include "\
"only the points shown.\n"\
"Do you want to proceed?"

box = QMessageBox(self)
box.setWindowTitle('ManifoldEM')
Expand All @@ -838,26 +794,26 @@ def rebed(self):
if box.exec_() == QMessageBox.No:
return

self.btn_reset.setDisabled(True)
self.btn_rebed.setDisabled(True)
self.parent_view.vid_tabs.setTabEnabled(0, False)
self.parent_view.vid_tabs.setTabEnabled(2, False)
self.parent_view.vid_tabs.setTabEnabled(3, False)
self.parent_view.vid_tabs.setTabEnabled(4, False)
self.parent_view.vid_tabs.setTabEnabled(5, False)

prds = data_store.get_prds()
if self.prd_index - 1 not in prds.reembed_ids: #only make a copy of current if this is user's first re-embedding
_backup_restore(self.prd_index - 1, backup=True) #makes copy in Topos/PrD and DiffMaps
prds.reembed_ids.add(self.prd_index - 1)
prds.save()

self.pts_orig, pts_orig_zip = itertools.tee(self.pts_orig)
self.pts_new, pts_new_zip = itertools.tee(self.pts_new)
embedd(list(self.pts_orig), list(self.pts_new), self.prd_index - 1) #updates all manifold files for PD

embedd(list(pts_orig_zip), list(pts_new_zip), self.prd_index - 1) #updates all manifold files for PD
self.redo_prd_analysis()
self.pts_orig = self.pts_new

self.start_task1()

def reload_psi_coords(self):
psi_file = p.get_psi_file(self.prd_index - 1) #current embedding
with open(psi_file, 'rb') as f:
data = pickle.load(f)

x = data['psi'][:, self.eigChoice1]
y = data['psi'][:, self.eigChoice2]
self.pts_orig = self.pts_new = list(zip(x, y))


def revert(self):
Expand All @@ -876,56 +832,14 @@ def revert(self):
if box.exec_() == QMessageBox.No:
return

self.btn_reset.setDisabled(False)
self.btn_rebed.setDisabled(True)
self.btn_revert.setDisabled(True)

self.parent_view.vid_tabs.setTabEnabled(0, False)
self.parent_view.vid_tabs.setTabEnabled(2, False)
self.parent_view.vid_tabs.setTabEnabled(3, False)
self.parent_view.vid_tabs.setTabEnabled(4, False)
self.parent_view.vid_tabs.setTabEnabled(5, False)

prds = data_store.get_prds()
prds.reembed_ids.discard(self.prd_index - 1)
prds.save()
_backup_restore(self.prd_index - 1, backup=False)

psi_file = p.get_psi_file(self.prd_index - 1) #current embedding
with open(psi_file, 'rb') as f:
data = pickle.load(f)

x = data['psi'][:, self.eigChoice1]
y = data['psi'][:, self.eigChoice2]

# redraw and resize figure:
self.ax.clear()
self.pts_orig = zip(x, y)
self.pts_origX = x
self.pts_origY = y
for i in self.pts_orig:
x, y = i
self.ax.scatter(x, y, s=1, c='#1f77b4') #plot initial data, C0

for i in self.pts_orig:
x, y = i
self.ax.scatter(x, y, s=1, c='#1f77b4') #plot initial data, C0
for tick in self.ax.xaxis.get_major_ticks():
tick.label1.set_fontsize(4)
for tick in self.ax.yaxis.get_major_ticks():
tick.label1.set_fontsize(4)
self.ax.get_xaxis().set_ticks([])
self.ax.get_yaxis().set_ticks([])
self.ax.set_title('Place points on the plot to encircle deviant cluster(s)', fontsize=3.5)
self.ax.set_xlabel(r'$\mathrm{\Psi}$%s' % (self.eigChoice1 + 1), fontsize=6)
self.ax.set_ylabel(r'$\mathrm{\Psi}$%s' % (self.eigChoice2 + 1), fontsize=6)
self.ax.autoscale()
self.canvas.draw()

self.parent_view.vid_tabs.setTabEnabled(0, True)
self.parent_view.vid_tabs.setTabEnabled(2, True)
self.parent_view.vid_tabs.setTabEnabled(3, True)
self.parent_view.vid_tabs.setTabEnabled(4, True)
self.parent_view.vid_tabs.setTabEnabled(5, True)
self.reload_psi_coords()
self.redraw()

msg = f'The manifold for PD {self.prd_index} has been successfully reverted.'
box = QMessageBox(self)
Expand All @@ -943,8 +857,8 @@ def revert(self):

def view(self): #view average of all images in encircled region
path = PlotPath(list(map(list, zip(self.coordsX, self.coordsY))), closed=True, codes=None, readonly=True)
inside_mask = path.contains_points(np.dstack((self.pts_origX, self.pts_origY))[0].tolist(),
radius=1e-9)
x, y = zip(*self.pts_orig)
inside_mask = path.contains_points(np.dstack((x, y))[0].tolist(), radius=1e-9)

idx_encircled = list(np.nonzero(inside_mask)[0])
print('Encircled Points:', len(idx_encircled))
Expand Down Expand Up @@ -974,59 +888,43 @@ def onclick(self, event):
self.btn_connect.setDisabled(False)


##########
# Task 1:
@QtCore.Slot()
def start_task1(self):
p.save() #send new GUI data to parameters file

task1 = threading.Thread(target=psiAnalysis.op, args=(self.progress1Changed, ))
task1.daemon = True
task1.start()


@QtCore.Slot(int)
def on_progress1Changed(self, val):
self.progress1.setValue(val / 2)
if val / 2 == 50:
self.start_task2()


##########
# Task 2:
@QtCore.Slot()
def start_task2(self):
def redo_prd_analysis(self):
p.save() #send new GUI data to parameters file
print(f"Re-running spectral analysis for prd {self.prd_index - 1}")
from ManifoldEM.psiAnalysis import psi_analysis_single
prd = self.prd_index - 1
dist_file = p.get_dist_file(prd)
psi_file = p.get_psi_file(prd)
psi2_file = p.get_psi2_file(prd)
EL_file = p.get_EL_file(prd)
psinums = list(range(p.num_psis))
senses = np.ones(p.num_psis)
psi_list = list(range(p.num_psis)) # list of incomplete psi values per PD
psi_analysis_single([dist_file, psi_file, psi2_file, EL_file, psinums, senses, prd, psi_list],
con_order_range=p.conOrderRange,
traj_name=p.trajName,
is_full=0,
psi_trunc=p.num_psiTrunc)

print(f"Re-making NLSA movie for prd {self.prd_index - 1}")
from ManifoldEM.NLSAmovie import movie
movie([prd], None, None, p.psi2_file, p.fps)

msg = f'The manifold for PD {self.prd_index} has been successfully re-embedded.'
box = QMessageBox(self)
box.setWindowTitle('ManifoldEM Re-embedding')
box.setText('<b>Re-embed Manifold</b>')
box.setIcon(QMessageBox.Warning)
box.setInformativeText(msg)
box.setStandardButtons(QMessageBox.Ok)
box.setDefaultButton(QMessageBox.Ok)
box.exec_()

task2 = threading.Thread(target=NLSAmovie.op, args=(self.progress2Changed, ))
task2.daemon = True
task2.start()


@QtCore.Slot(int)
def on_progress2Changed(self, val):
self.progress1.setValue(val / 2 + 50)
if (val / 2 + 50) == 100:
self.parent_view.vid_tabs.setTabEnabled(0, True)
self.parent_view.vid_tabs.setTabEnabled(2, True)
self.parent_view.vid_tabs.setTabEnabled(3, True)
self.parent_view.vid_tabs.setTabEnabled(4, True)
self.parent_view.vid_tabs.setTabEnabled(5, True)

msg = f'The manifold for PD {self.prd_index} has been successfully re-embedded.'
box = QMessageBox(self)
box.setWindowTitle('ManifoldEM Re-embedding')
box.setText('<b>Re-embed Manifold</b>')
box.setIcon(QMessageBox.Warning)
box.setFont(font_standard)
box.setInformativeText(msg)
box.setStandardButtons(QMessageBox.Ok)
box.setDefaultButton(QMessageBox.Ok)
box.exec_()

# force-update main GUI window (topos images)
self.data_changed.emit()
# force-update main GUI window (topos images)
self.data_changed.emit()

# reset the tab
self.redraw()


class _CCDetailsView(QMainWindow):
Expand All @@ -1040,7 +938,7 @@ def __init__(self, prd_index: int, psi_index: int):
def initUI(self):
gif_path = p.get_psi_gif(self.prd_index, self.psi_index)
self.vid_tab1 = VidCanvas(gif_path, parent=self)
self.vid_tab2 = QDialog(self) # Manifold2dCanvas(self.prd_index, self)
self.vid_tab2 = Manifold2dCanvas(self.prd_index, self)
self.vid_tab3 = QDialog(self) # Manifold3dCanvas(self)
self.vid_tab4 = ChronosCanvas(self.prd_index, self.psi_index, self)
self.vid_tab5 = PsiCanvas(self.prd_index, self.psi_index, self)
Expand Down Expand Up @@ -1074,7 +972,5 @@ def onTabChange(self, i):
self.vid_tab1.stop_movie()


# FIXME attach signals
def connect_signals(self, data_change_callback):
return
self.vid_tab2.data_changed.connect(data_change_callback)
Loading

0 comments on commit fedde8c

Please sign in to comment.