From 61e3648c8f9b1f9791c3fae3730326cbc8573907 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20M=C3=BCller?= <82948763+Benoit-Muller@users.noreply.github.com> Date: Thu, 22 Dec 2022 19:18:42 +0100 Subject: [PATCH] Delete notebook_Armelle.ipynb --- notebook_Armelle.ipynb | 885 ----------------------------------------- 1 file changed, 885 deletions(-) delete mode 100644 notebook_Armelle.ipynb diff --git a/notebook_Armelle.ipynb b/notebook_Armelle.ipynb deleted file mode 100644 index adbae0e..0000000 --- a/notebook_Armelle.ipynb +++ /dev/null @@ -1,885 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 2, - "id": "0fb1b20d", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "from PIL import Image\n", - "import time\n", - "%matplotlib inline\n", - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "from sklearn.decomposition import TruncatedSVD\n", - "from sinkhorn import sinkhorn\n", - "from builders import image2array, array2cost, image2array, transfer_color, array2image" - ] - }, - { - "cell_type": "markdown", - "id": "d98089f4", - "metadata": {}, - "source": [ - "### Understanding how the truncated svd in scikitlearn works" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "eb0b90e8", - "metadata": {}, - "outputs": [], - "source": [ - "X=np.random.rand(5,6)*100\n", - "X=np.round(X,0)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "f59343d9", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[88., 70., 21., 53., 64., 91.],\n", - " [ 0., 73., 21., 59., 66., 42.],\n", - " [16., 36., 15., 24., 27., 68.],\n", - " [ 6., 48., 21., 21., 58., 89.],\n", - " [67., 51., 19., 14., 56., 99.]])" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "X" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "5d63ff6c", - "metadata": {}, - "outputs": [], - "source": [ - "svd=TruncatedSVD(2) #choosing the rank r of the approximation, here a rank 3 approximation\n", - "US=svd.fit_transform(X)\n", - "V=svd.components_" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "78a5f904", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[ 72.25, 65.42, 23.05, 35.6 , 65.26, 108.73],\n", - " [ -9.5 , 68.19, 22.02, 52.96, 64.19, 54.05],\n", - " [ 26.11, 39.02, 13.35, 24.41, 38.16, 52.98],\n", - " [ 24.24, 57.11, 19.2 , 38.33, 55.21, 67.78],\n", - " [ 71.76, 50.8 , 18.27, 24.71, 51.38, 95.38]])" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.round(US@V,2) #approximation of X" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "ad5eb63a", - "metadata": {}, - "outputs": [], - "source": [ - "u=np.array([[1,2,3,4,10]]).T\n", - "v=np.array([[5,5,5,6,6,6]]).T" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "cf2529ab", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.0\n", - "0.0\n", - "0.0\n" - ] - } - ], - "source": [ - "start=time.time()\n", - "u*(US@V)*v.T\n", - "execution_time=time.time()-start\n", - "print(execution_time)\n", - "\n", - "start=time.time()\n", - "(u*US)@(V*v.T)\n", - "execution_time=time.time()-start\n", - "print(execution_time)\n", - "\n", - "start=time.time()\n", - "u*(US@(V*v.T))\n", - "execution_time=time.time()-start\n", - "print(execution_time)" - ] - }, - { - "cell_type": "markdown", - "id": "b03e7ccb", - "metadata": {}, - "source": [ - "### Initialisation of initial points " - ] - }, - { - "cell_type": "markdown", - "id": "c1a41bac", - "metadata": {}, - "source": [ - "## Test if low rank approximation improves the computational time" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "35de34f0", - "metadata": {}, - "outputs": [], - "source": [ - "A=np.random.rand(100**2,100**2)\n", - "A=A*10000" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "61861f30", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "10000" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.linalg.matrix_rank(A, tol=None, hermitian=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "2034b8c3", - "metadata": {}, - "outputs": [], - "source": [ - "ones=np.ones(100**2)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "e364616c", - "metadata": {}, - "outputs": [], - "source": [ - "start_time=time.time()\n", - "B=A@ones\n", - "end=time.time()-start_time" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "bd1a8344", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "0.05285954475402832" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "end" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "8e517c10", - "metadata": {}, - "outputs": [], - "source": [ - "svd=TruncatedSVD(500)\n", - "US=svd.fit_transform(A) \n", - "V=svd.components_\n", - "start=time.time()\n", - "US@(V@ones)\n", - "end=time.time()-start" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "c8061235", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "0.033837318420410156" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "end" - ] - }, - { - "cell_type": "markdown", - "id": "4f9fecf4", - "metadata": {}, - "source": [ - "It does really work !!!" - ] - }, - { - "cell_type": "markdown", - "id": "cd93277e", - "metadata": {}, - "source": [ - "### Sinkhorn where the rank of the rank approximation varies" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "583474db", - "metadata": {}, - "outputs": [], - "source": [ - "# Points:\n", - "#n=100**2 #original number of datapoints, chose a lower one for runntime reasons \n", - "n=100\n", - "x = np.linspace(0,1,n)\n", - "y = np.linspace(0,1,n)\n", - "x=x[:,np.newaxis]\n", - "y=y[:,np.newaxis]\n", - "# Cost:\n", - "C = (x-y.T)**2\n", - "# entropy factor:\n", - "eta = 1 # il manque W <- W/eta dans l'algo alors garder eta=1\n", - "# (exact) Kernel:\n", - "Kmat = np.exp(-eta*C)\n", - "#def K(v): -----------#I have put them in comments since I won't use them \n", - "# return Kmat@v\n", - "#def Kt(v):\n", - "# return (Kmat.T)@v\n", - "# Target marginals:\n", - "p = np.ones((n,1))\n", - "p = p / np.sum(p)\n", - "q = np.ones((n,1))\n", - "q = q / np.sum(q) \n", - "# tolerance:\n", - "delta = 1e-15" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "a7b57331", - "metadata": {}, - "outputs": [], - "source": [ - "rank_K=np.linalg.matrix_rank(Kmat, tol=None, hermitian=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "3d14912c", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "10" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "rank_K" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "1d30fe7f", - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'list_time' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)", - "Input \u001b[1;32mIn [18]\u001b[0m, in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[1;32m----> 1\u001b[0m plt\u001b[38;5;241m.\u001b[39mplot(\u001b[43mlist_time\u001b[49m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m.-\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", - "\u001b[1;31mNameError\u001b[0m: name 'list_time' is not defined" - ] - } - ], - "source": [ - "plt.plot(list_time, \".-\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2324eb50", - "metadata": {}, - "outputs": [], - "source": [ - "plt.plot(list_error, \".-\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dce0d80b", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b2323b36", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9b088af8", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "aecbd678", - "metadata": {}, - "outputs": [], - "source": [ - "N = 100 # heigth of the image (= width)\n", - "n = N^2 # number of pixels\n", - "eta = 15 # lent si plus de 20\n", - "delta = 1e-15 # tolerance for sinkhorn()\n", - "img1_nbr = '3' #4\n", - "img2_nbr = '1' #2\n", - "img1_name = 'img' + img1_nbr + '_' + str(N) + '.jpg'\n", - "img2_name = 'img' + img2_nbr + '_' + str(N) + '.jpg'\n", - "img1 = image2array(img1_name) # source image\n", - "img2 = image2array(img2_name) # target image\n", - "C, p, q = array2cost(img1, img2) # cost and coupling marginals\n", - "Kmat = np.exp(-eta * C) # kernel to compute the sinkhorn projection \n", - "\n", - "def K_full(v):\n", - " ''' Kernel-vector matrix product'''\n", - " return Kmat @ v\n", - "def Kt_full(v):\n", - " ''' Transposed_kernel-vector matrix product'''\n", - " return (Kmat.T) @ v" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "c0cbacbc", - "metadata": {}, - "outputs": [], - "source": [ - "k=200" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "67c96456", - "metadata": {}, - "outputs": [], - "source": [ - "delta=10**(-10)\n", - "svd=TruncatedSVD(k)\n", - "US=svd.fit_transform(Kmat) \n", - "V=svd.components_\n", - "def K(v):\n", - " return US@(V@v)\n", - "def Kt(v):\n", - " return V.T@(US.T@v)\n", - "S_time=time.time()\n", - "[u,v,W,err]=sinkhorn(K,Kt,p,q,delta,maxtime=60)\n", - "end_time=time.time()-S_time\n", - "P=(u*US)@(V*v.T) #computing associated coupling P matrix" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "3b522ec4", - "metadata": {}, - "outputs": [], - "source": [ - "k=500\n", - "delta=10**(-10)\n", - "svd=TruncatedSVD(k)\n", - "US_500=svd.fit_transform(Kmat) \n", - "V_500=svd.components_\n", - "def K_500(v):\n", - " return US_500@(V_500@v)\n", - "def Kt_500(v):\n", - " return V_500.T@(US_500.T@v)\n", - "S_time=time.time()\n", - "[u_500,v_500,W_500,err_500]=sinkhorn(K_500,Kt_500,p,q,delta,maxtime=60)\n", - "end_time_500=time.time()-S_time\n", - "P_500=(u_500*US_500)@(V_500*v_500.T) #computing associated coupling P matrix" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "71744c6e", - "metadata": {}, - "outputs": [], - "source": [ - "k=400\n", - "delta=10**(-10)\n", - "svd=TruncatedSVD(k)\n", - "US_400=svd.fit_transform(Kmat) \n", - "V_400=svd.components_\n", - "def K_400(v):\n", - " return US_500@(V_500@v)\n", - "def Kt_400(v):\n", - " return V_500.T@(US_500.T@v)\n", - "S_time=time.time()\n", - "[u_400,v_400,W_400,err_400]=sinkhorn(K_400,Kt_400,p,q,delta,maxtime=60)\n", - "end_time_400=time.time()-S_time\n", - "P_400=(u_400*US_400)@(V_400*v_400.T) #computing associated coupling P matrix" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "id": "4cbcb1d9", - "metadata": {}, - "outputs": [], - "source": [ - "k=600\n", - "delta=10**(-10)\n", - "svd=TruncatedSVD(k)\n", - "US_600=svd.fit_transform(Kmat) \n", - "V_600=svd.components_\n", - "def K_600(v):\n", - " return US_600@(V_600@v)\n", - "def Kt_600(v):\n", - " return V_600.T@(US_600.T@v)\n", - "S_time=time.time()\n", - "[u_600,v_600,W_600,err_600]=sinkhorn(K_600,Kt_600,p,q,delta,maxtime=60)\n", - "end_time_600=time.time()-S_time\n", - "P_600=(u_600*US_600)@(V_600*v_600.T) #computing associated coupling P matrix" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "8127faa1", - "metadata": {}, - "outputs": [], - "source": [ - "u_full,v_full,W_full,err_full = sinkhorn(K_full,Kt_full,p,q,delta,maxtime=60)\n", - "P_full = u_full*Kmat*v_full.T # the coupling" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "cdfe5cbb", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(10000, 10000)" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "P_full.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "d9d0432d", - "metadata": {}, - "outputs": [], - "source": [ - "def norm(P):\n", - " print(np.ones(P.shape[0]).T@P@np.ones(P.shape[0]))" - ] - }, - { - "cell_type": "code", - "execution_count": 43, - "id": "66433826", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "-15.131681494908829 -15.131681494889289 -15.131681494908408 -15.131681494908408 -15.13168149490895\n" - ] - } - ], - "source": [ - "print(W_full, W, W_400, W_500, W_600)" - ] - }, - { - "cell_type": "code", - "execution_count": 44, - "id": "7de31781", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array(-15.13168149)" - ] - }, - "execution_count": 44, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "W_full" - ] - }, - { - "cell_type": "code", - "execution_count": 45, - "id": "c7cd793c", - "metadata": {}, - "outputs": [], - "source": [ - "W = (np.log(u).T @ (u*K(v)) + np.log(v).T @ (v*Kt(u)))" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "id": "d77b0965", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[-15.13168149]])" - ] - }, - "execution_count": 46, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "W" - ] - }, - { - "cell_type": "code", - "execution_count": 50, - "id": "2052e8f8", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[3.68113169]])" - ] - }, - "execution_count": 50, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.log(u).T @ (u*K(v))" - ] - }, - { - "cell_type": "code", - "execution_count": 52, - "id": "25bd825e", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[-18.81281319]])" - ] - }, - "execution_count": 52, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.log(v).T @ (v*Kt(u))" - ] - }, - { - "cell_type": "code", - "execution_count": 53, - "id": "4af5d129", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "3.239284402616884e-10" - ] - }, - "execution_count": 53, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.min(v)" - ] - }, - { - "cell_type": "code", - "execution_count": 54, - "id": "dcfa9fb9", - "metadata": {}, - "outputs": [], - "source": [ - "u, s, vh = np.linalg.svd(Kmat, full_matrices=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 55, - "id": "4dd9bb71", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "2270.9321040889154" - ] - }, - "execution_count": 55, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.linalg.norm( u@s@vh-Kmat,1)" - ] - }, - { - "cell_type": "code", - "execution_count": 89, - "id": "a5295617", - "metadata": {}, - "outputs": [], - "source": [ - "O=np.array([[1,2,3,4],[1,2,3,4],[3,4,5,6],[4,5,6,7]])" - ] - }, - { - "cell_type": "code", - "execution_count": 90, - "id": "19fbbb1b", - "metadata": {}, - "outputs": [], - "source": [ - "u, s, vh = np.linalg.svd(O, full_matrices=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 91, - "id": "ae1a7b45", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([1.64316767e+01, 1.41421356e+00, 5.33113873e-16, 6.10226137e-17])" - ] - }, - "execution_count": 91, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "s" - ] - }, - { - "cell_type": "code", - "execution_count": 92, - "id": "61205684", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[1., 2., 3., 4.],\n", - " [1., 2., 3., 4.],\n", - " [3., 4., 5., 6.],\n", - " [4., 5., 6., 7.]])" - ] - }, - "execution_count": 92, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "(u * s) @ vh" - ] - }, - { - "cell_type": "code", - "execution_count": 93, - "id": "39d03dc1", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(4, 4)" - ] - }, - "execution_count": 93, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "u.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 94, - "id": "ecae37d8", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(4, 4)" - ] - }, - "execution_count": 94, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "(u*s).shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c9742007", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.13" - }, - "vscode": { - "interpreter": { - "hash": "233a531365d7bd5abb8382eb032c18c305e1c6b951add6f6a5c925475bc609cb" - } - } - }, - "nbformat": 4, - "nbformat_minor": 5 -}