From d9e43181e227a25abf8cb1499402ecc690928643 Mon Sep 17 00:00:00 2001 From: Frederik Baerentsen Date: Thu, 29 Jan 2026 19:22:54 +0100 Subject: [PATCH] Initial Upload --- CVIssueCount.ini | 4 + CVIssueCount.png | Bin 0 -> 518 bytes CVIssueCount.py | 618 +++++++++++++++++ Package.ini | 4 + UserDict.py | 213 ++++++ __future__.py | 128 ++++ _abcoll.py | 695 ++++++++++++++++++++ _weakrefset.py | 204 ++++++ abc.py | 185 ++++++ base64.py | 364 +++++++++++ bisect.py | 92 +++ collections.py | 730 +++++++++++++++++++++ contextlib.py | 154 +++++ functools.py | 100 +++ genericpath.py | 113 ++++ hashlib.py | 221 +++++++ heapq.py | 485 ++++++++++++++ httplib.py | 1445 ++++++++++++++++++++++++++++++++++++++++ io.py | 90 +++ ipypulldom.py | 66 ++ keyword.py | 93 +++ linecache.py | 139 ++++ mimetools.py | 250 +++++++ ntpath.py | 550 ++++++++++++++++ nturl2path.py | 68 ++ os.py | 742 +++++++++++++++++++++ posixpath.py | 439 +++++++++++++ random.py | 910 ++++++++++++++++++++++++++ rfc822.py | 1016 ++++++++++++++++++++++++++++ socket.py | 577 ++++++++++++++++ ssl.py | 464 +++++++++++++ stat.py | 96 +++ string.py | 656 +++++++++++++++++++ struct.py | 3 + tempfile.py | 639 ++++++++++++++++++ textwrap.py | 429 ++++++++++++ types.py | 86 +++ urllib.py | 1637 ++++++++++++++++++++++++++++++++++++++++++++++ urllib2.py | 1488 +++++++++++++++++++++++++++++++++++++++++ urlparse.py | 428 ++++++++++++ warnings.py | 422 ++++++++++++ weakref.py | 458 +++++++++++++ xml2py.py | 121 ++++ 43 files changed, 17622 insertions(+) create mode 100644 CVIssueCount.ini create mode 100644 CVIssueCount.png create mode 100644 CVIssueCount.py create mode 100644 Package.ini create mode 100644 UserDict.py create mode 100644 __future__.py create mode 100644 _abcoll.py create mode 100644 _weakrefset.py create mode 100644 abc.py create mode 100644 base64.py create mode 100644 bisect.py create mode 100644 collections.py create mode 100644 contextlib.py create mode 100644 functools.py create mode 100644 genericpath.py create mode 100644 hashlib.py create mode 100644 heapq.py create mode 100644 httplib.py create mode 100644 io.py create mode 100644 ipypulldom.py create mode 100644 keyword.py create mode 100644 linecache.py create mode 100644 mimetools.py create mode 100644 ntpath.py create mode 100644 nturl2path.py create mode 100644 os.py create mode 100644 posixpath.py create mode 100644 random.py create mode 100644 rfc822.py create mode 100644 socket.py create mode 100644 ssl.py create mode 100644 stat.py create mode 100644 string.py create mode 100644 struct.py create mode 100644 tempfile.py create mode 100644 textwrap.py create mode 100644 types.py create mode 100644 urllib.py create mode 100644 urllib2.py create mode 100644 urlparse.py create mode 100644 warnings.py create mode 100644 weakref.py create mode 100644 xml2py.py diff --git a/CVIssueCount.ini b/CVIssueCount.ini new file mode 100644 index 0000000..c893ee4 --- /dev/null +++ b/CVIssueCount.ini @@ -0,0 +1,4 @@ +api_key=9718933c0a3aa4407d6995d0f3bb310f5ee977f2 +skip_api_if_not_in_db=false +always_use_api=false +api_delay=10 diff --git a/CVIssueCount.png b/CVIssueCount.png new file mode 100644 index 0000000000000000000000000000000000000000..2e02adc3fbadcf9bd865e9db111fdfd3513f380e GIT binary patch literal 518 zcmV+h0{Q)kP)>OFe9bM6NjEbgz~nS0N9X3lfY{e&n{qWpuj z16M#kpl(AjUnytIR!5*4C;*)Sbp~WZa9LqfKtIqU-=k6ehAIt0=oxU{e> zz#DK590Pl@KPOzS)B^P~;J)fx2kwE2Ac0D$AORa^1e`%GW#%Spz#S0t5~u=JfoG`| zX>$wsaYn$&kgUQkuqE45(q|Qz0A7I+DQA1!BA~fZ7pDcB^a9(!1~4GAw+fWWrKW#; z0NhBQ5wLMbigOEirAVMgmS{#n8EK0v0UI}~xNi$2Q-zLmj}<1 zi~BW!Y0K=31PUd19Gm?07*qo IM6N<$f&*C5+yDRo literal 0 HcmV?d00001 diff --git a/CVIssueCount.py b/CVIssueCount.py new file mode 100644 index 0000000..262c1fd --- /dev/null +++ b/CVIssueCount.py @@ -0,0 +1,618 @@ +# +#CVIssueCount.py +# +#Author: Frederik Baerentsen +# +#Description: Complete Story Arc with the Comic Vine info +# +#Versions: +# V1 0.1 First version +# V2 0.2 Added Progressbar, start and cancel buttons +# V2 0.3 Added issue count +# V3 0.4 Added local database support +# V3 0.5 Added settings dialog (API key, local DB options, API delay) +# V3 1.0 Performance improvements and stability fixes +# + +# $ cd "..\..\Program Files\ComicRack\ +# $ ComicRack.exe -ssc + +from __future__ import unicode_literals + +import clr, re, sys, os, urlparse, time + +from System.Diagnostics import Process +clr.AddReference("System.xml") +import System +from System import * +from System.IO import * +from System.Collections import * +from System.Threading import * +from System.Net import * +from System.Text import * + +clr.AddReference('System') +clr.AddReference("System.Windows.Forms") +clr.AddReference("System.Drawing") + +import System.Drawing +import System.Windows.Forms + +from System.Drawing import * +from System.Windows.Forms import * + +clr.AddReference('System.Drawing') +from System.Drawing import Point, Size, ContentAlignment, Color, SystemColors, Icon +clr.AddReference('System.Threading') +from System.Threading import * +from datetime import datetime +import ssl, urllib +from System.ComponentModel import BackgroundWorker + + +stop=False + +# ============== Settings Management ============== + +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +SETTINGS_FILE = os.path.join(SCRIPT_DIR, "CVIssueCount.ini") + +class Settings: + def __init__(self): + self.api_key = "" + self.skip_api_if_not_in_db = False + self.always_use_api = False + self.api_delay = 10 # seconds between API calls, 0 to disable + self.load_from_file() + + def load_from_file(self): + """Load settings from INI file.""" + if os.path.exists(SETTINGS_FILE): + try: + with open(SETTINGS_FILE, 'r') as f: + for line in f: + line = line.strip() + if '=' in line: + key, value = line.split('=', 1) + key = key.strip() + value = value.strip() + if key == 'api_key': + self.api_key = value + elif key == 'skip_api_if_not_in_db': + self.skip_api_if_not_in_db = value.lower() == 'true' + elif key == 'always_use_api': + self.always_use_api = value.lower() == 'true' + elif key == 'api_delay': + try: + self.api_delay = int(value) + except: + self.api_delay = 10 + except Exception as e: + print "Error loading settings: " + str(e) + + def save_to_file(self): + """Save settings to INI file.""" + try: + with open(SETTINGS_FILE, 'w') as f: + f.write("api_key=" + self.api_key + "\n") + f.write("skip_api_if_not_in_db=" + str(self.skip_api_if_not_in_db).lower() + "\n") + f.write("always_use_api=" + str(self.always_use_api).lower() + "\n") + f.write("api_delay=" + str(self.api_delay) + "\n") + except Exception as e: + print "Error saving settings: " + str(e) + +# Global settings instance +settings = Settings() + +# ============== SQLite Initialization ============== + +_sqlite_initialized = False +_sqlite_available = False +_SQLiteConnection = None + +def _init_sqlite(): + """Initialize SQLite once at module load. Returns True if successful.""" + global _sqlite_initialized, _sqlite_available, _SQLiteConnection + + if _sqlite_initialized: + return _sqlite_available + + _sqlite_initialized = True + + try: + script_dir = os.path.dirname(os.path.abspath(__file__)) + sqlite_dll = Path.Combine(script_dir, 'System.Data.SQLite.dll') + + if not File.Exists(sqlite_dll): + print "System.Data.SQLite.dll not found at: " + sqlite_dll + return False + + # Set DLL search path to find e_sqlite3.dll (do this once) + if System.IntPtr.Size == 8: + native_path = Path.Combine(script_dir, 'x64') + else: + native_path = Path.Combine(script_dir, 'x86') + + current_path = System.Environment.GetEnvironmentVariable("PATH") + if native_path not in current_path: + System.Environment.SetEnvironmentVariable("PATH", native_path + ";" + current_path) + + clr.AddReferenceToFileAndPath(sqlite_dll) + from System.Data.SQLite import SQLiteConnection + _SQLiteConnection = SQLiteConnection + + _sqlite_available = True + print "SQLite initialized successfully" + return True + + except Exception as e: + print "SQLite initialization error: " + str(e) + return False + +# Initialize SQLite at module load +_init_sqlite() + +# ============== Settings Dialog ============== + +class SettingsForm(Form): + def __init__(self): + self.InitializeComponent() + self.LoadSettings() + + def InitializeComponent(self): + self.SuspendLayout() + + # API Key Label + self._apiKeyLabel = Label() + self._apiKeyLabel.Text = "ComicVine API Key:" + self._apiKeyLabel.Location = Point(12, 15) + self._apiKeyLabel.Size = Size(120, 20) + + # API Key TextBox + self._apiKeyTextBox = TextBox() + self._apiKeyTextBox.Location = Point(140, 12) + self._apiKeyTextBox.Size = Size(280, 20) + self._apiKeyTextBox.UseSystemPasswordChar = True + + # Show/Hide API Key button + self._showApiKey = Button() + self._showApiKey.Text = "Show" + self._showApiKey.Location = Point(425, 10) + self._showApiKey.Size = Size(50, 24) + self._showApiKey.Click += self.ToggleApiKeyVisibility + + # Skip API checkbox + self._skipApiCheckBox = CheckBox() + self._skipApiCheckBox.Text = "Skip API if not found in local database (local DB only)" + self._skipApiCheckBox.Location = Point(12, 50) + self._skipApiCheckBox.Size = Size(400, 24) + self._skipApiCheckBox.CheckedChanged += self.OnSkipApiChanged + + # Always use API checkbox + self._alwaysApiCheckBox = CheckBox() + self._alwaysApiCheckBox.Text = "Always use API (skip local database)" + self._alwaysApiCheckBox.Location = Point(12, 80) + self._alwaysApiCheckBox.Size = Size(400, 24) + self._alwaysApiCheckBox.CheckedChanged += self.OnAlwaysApiChanged + + # API Delay Label + self._apiDelayLabel = Label() + self._apiDelayLabel.Text = "API delay (seconds):" + self._apiDelayLabel.Location = Point(12, 115) + self._apiDelayLabel.Size = Size(120, 20) + + # API Delay TextBox + self._apiDelayTextBox = TextBox() + self._apiDelayTextBox.Location = Point(140, 112) + self._apiDelayTextBox.Size = Size(50, 20) + + # API Delay hint + self._apiDelayHint = Label() + self._apiDelayHint.Text = "(0 = no delay)" + self._apiDelayHint.Location = Point(195, 115) + self._apiDelayHint.Size = Size(100, 20) + self._apiDelayHint.ForeColor = Color.Gray + + # Info label + self._infoLabel = Label() + self._infoLabel.Text = "Note: Local database (localcv.db) must be in the plugin folder." + self._infoLabel.Location = Point(12, 145) + self._infoLabel.Size = Size(460, 20) + self._infoLabel.ForeColor = Color.Gray + + # Save Button + self._saveButton = Button() + self._saveButton.Text = "Save" + self._saveButton.Location = Point(310, 175) + self._saveButton.Size = Size(80, 28) + self._saveButton.Click += self.SaveClicked + + # Cancel Button + self._cancelButton = Button() + self._cancelButton.Text = "Cancel" + self._cancelButton.Location = Point(395, 175) + self._cancelButton.Size = Size(80, 28) + self._cancelButton.Click += self.CancelClicked + + # Form settings + self.ClientSize = Size(490, 215) + self.Text = "CVIssueCount Settings" + self.FormBorderStyle = FormBorderStyle.FixedDialog + self.MaximizeBox = False + self.MinimizeBox = False + self.StartPosition = FormStartPosition.CenterParent + self.AcceptButton = self._saveButton + self.CancelButton = self._cancelButton + + # Add controls + self.Controls.Add(self._apiKeyLabel) + self.Controls.Add(self._apiKeyTextBox) + self.Controls.Add(self._showApiKey) + self.Controls.Add(self._skipApiCheckBox) + self.Controls.Add(self._alwaysApiCheckBox) + self.Controls.Add(self._apiDelayLabel) + self.Controls.Add(self._apiDelayTextBox) + self.Controls.Add(self._apiDelayHint) + self.Controls.Add(self._infoLabel) + self.Controls.Add(self._saveButton) + self.Controls.Add(self._cancelButton) + + self.ResumeLayout(False) + + def LoadSettings(self): + self._apiKeyTextBox.Text = settings.api_key + self._skipApiCheckBox.Checked = settings.skip_api_if_not_in_db + self._alwaysApiCheckBox.Checked = settings.always_use_api + self._apiDelayTextBox.Text = str(settings.api_delay) + + def ToggleApiKeyVisibility(self, sender, e): + self._apiKeyTextBox.UseSystemPasswordChar = not self._apiKeyTextBox.UseSystemPasswordChar + self._showApiKey.Text = "Hide" if not self._apiKeyTextBox.UseSystemPasswordChar else "Show" + + def OnSkipApiChanged(self, sender, e): + if self._skipApiCheckBox.Checked: + self._alwaysApiCheckBox.Checked = False + + def OnAlwaysApiChanged(self, sender, e): + if self._alwaysApiCheckBox.Checked: + self._skipApiCheckBox.Checked = False + + def SaveClicked(self, sender, e): + settings.api_key = self._apiKeyTextBox.Text + settings.skip_api_if_not_in_db = self._skipApiCheckBox.Checked + settings.always_use_api = self._alwaysApiCheckBox.Checked + try: + settings.api_delay = int(self._apiDelayTextBox.Text) + if settings.api_delay < 0: + settings.api_delay = 0 + except: + settings.api_delay = 10 + settings.save_to_file() + self.DialogResult = DialogResult.OK + self.Close() + + def CancelClicked(self, sender, e): + self.DialogResult = DialogResult.Cancel + self.Close() + +# ============== Main Entry Points ============== + +#@Name CVIssueCount +#@Hook Books +#@Image CVIssueCount.png +#@Key CVIssueCount +def CVIssueCount(books): + """Main entry point - called from context menu on Books.""" + if books: + f = CVIssueCountForm(books) + r = f.ShowDialog() + else: + MessageBox.Show("No books selected") + return + +#@Key CVIssueCount +#@Hook ConfigScript +def CVIssueCountSettings(): + """Entry point for ConfigScript hook - shows settings from File menu.""" + f = SettingsForm() + f.ShowDialog() + +# ============== Main Form ============== + +class CVIssueCountForm(Form): + def __init__(self, books): + self.InitializeComponent() + self.books = books + self.percentage = 1.0/len(books)*100 + self.progresspercent = 0.0 + self.stop = False + + def InitializeComponent(self): + self._label1 = System.Windows.Forms.Label() + self._Okay = System.Windows.Forms.Button() + self._Cancel = System.Windows.Forms.Button() + self._Settings = System.Windows.Forms.Button() + self._progress = System.Windows.Forms.ProgressBar() + self.worker = BackgroundWorker() + self.SuspendLayout() + + self.worker.WorkerSupportsCancellation = True + self.worker.WorkerReportsProgress = True + self.worker.DoWork += self.DoWork + self.worker.ProgressChanged += self.ReportProgress + self.worker.RunWorkerCompleted += self.WorkerCompleted + + self._progress.Location = System.Drawing.Point(12, 114) + self._progress.Size = System.Drawing.Size(456, 17) + self._progress.TabIndex = 4 + # + # label1 + # + self._label1.AutoSize = True + self._label1.Location = System.Drawing.Point(12, 53) + self._label1.Name = "label1" + self._label1.Size = System.Drawing.Size(16, 13) + self._label1.TabIndex = 3 + self._label1.Text = "Start searching Comicvine for issue counts using the 'Start' button" + # + # Settings Button + # + self._Settings.Location = System.Drawing.Point(12, 137) + self._Settings.Name = "Settings" + self._Settings.Size = System.Drawing.Size(75, 23) + self._Settings.TabIndex = 7 + self._Settings.Text = "Settings" + self._Settings.UseVisualStyleBackColor = True + self._Settings.Click += self.SettingsClicked + # + # Okay + # + self._Okay.Location = System.Drawing.Point(312, 137) + self._Okay.Name = "Okay" + self._Okay.Size = System.Drawing.Size(75, 23) + self._Okay.TabIndex = 5 + self._Okay.Text = "Start" + self._Okay.UseVisualStyleBackColor = True + self._Okay.Click += self.OkayClicked + # + # Cancel + # + self._Cancel.Location = System.Drawing.Point(393, 137) + self._Cancel.Name = "Cancel" + self._Cancel.Size = System.Drawing.Size(75, 23) + self._Cancel.TabIndex = 6 + self._Cancel.Text = "Cancel" + self._Cancel.Enabled = False + self._Cancel.UseVisualStyleBackColor = True + self._Cancel.Click += self.CancelClicked + + self.ClientSize = System.Drawing.Size(483, 170) + self.Controls.Add(self._Settings) + self.Controls.Add(self._Cancel) + self.Controls.Add(self._Okay) + self.Controls.Add(self._progress) + self.Controls.Add(self._label1) + self.Name = "CVIssueCount" + self.Text = "Get issue Count from CV" + self.MinimizeBox = False + self.MaximizeBox = False + self.AcceptButton = self._Okay + self.FormBorderStyle = FormBorderStyle.FixedDialog + self.StartPosition = FormStartPosition.CenterParent + + self.ResumeLayout(False) + self.PerformLayout() + self.FormClosing += self.CheckClosing + + + + def CheckClosing(self, sender, e): + if self.worker.IsBusy: + self.worker.CancelAsync() + e.Cancel = True + #print "close" + + def SettingsClicked(self, sender, e): + f = SettingsForm() + f.ShowDialog() + + def OkayClicked(self, sender, e): + if self._Okay.Text == "Start": + if self.worker.IsBusy == False: + self.worker.RunWorkerAsync() + self._Cancel.Enabled = True + self._Okay.Enabled = False + self._Settings.Enabled = False + #else: + #self.DialogResult = DialogResult.OK + + def CancelClicked(self, sender, e): + if self.worker.IsBusy: + self.worker.CancelAsync() + self._Cancel.Enabled = False + + def DoWork(self, sender, e): + e.Result = getIssueCount(sender,self.books) + + def ReportProgress(self, sender, e): + self.progresspercent = self.percentage*e.ProgressPercentage + self._progress.Value = int(round(self.progresspercent)) + self._label1.Text = "Searching for " + e.UserState.ToString() + + + def WorkerCompleted(self, sender, e): + self._Okay.Text = "Done" + self.Close() + self.Dispose() + +# ============== Database Functions ============== + +def getIssueCountFromDB(volume_id, db_path): + """Try to get issue count from local database. Returns None if not found.""" + if not _sqlite_available: + return None + + conn = None + try: + conn = _SQLiteConnection("Data Source=" + db_path + ";Version=3;Read Only=True;BusyTimeout=5000;") + conn.Open() + cmd = conn.CreateCommand() + cmd.CommandText = "SELECT count_of_issues FROM cv_volume WHERE id = " + str(int(volume_id)) + result = cmd.ExecuteScalar() + if result is not None: + return int(result) + except Exception as e: + print "DB lookup error: " + str(e) + finally: + if conn is not None: + try: + conn.Close() + except: + pass + return None + +# ============== Main Logic ============== + +def getIssueCount(worker,books): + + # Get API key from settings + API_KEY = settings.api_key if settings.api_key else "" + + # Check for local database + script_dir = os.path.dirname(os.path.abspath(__file__)) + db_path = os.path.join(script_dir, "localcv.db") + db_exists = os.path.exists(db_path) + + # Determine what sources to use based on settings + use_local_db = db_exists and not settings.always_use_api + use_api = not settings.skip_api_if_not_in_db + + if settings.always_use_api: + print "Settings: Always use API (skipping local DB)" + elif settings.skip_api_if_not_in_db: + print "Settings: Skip API if not in local DB" + if not db_exists: + print "Warning: Local database not found!" + elif use_local_db: + print "Using local database: " + db_path + else: + print "Local database not found, using API" + + all_books_original = ComicRack.App.GetLibraryBooks() + + # Build a lookup dictionary: volume_id -> list of books (do this ONCE) + print "Building volume lookup index..." + volume_to_books = {} + for book in all_books_original: + vol = book.GetCustomValue("comicvine_volume") + if vol: + if vol not in volume_to_books: + volume_to_books[vol] = [] + volume_to_books[vol].append(book) + print "Index built: " + str(len(volume_to_books)) + " volumes" + + stop=False + CheckedVolumes=list() + IssueCountList=list() + IssueCount=0 + + failed = 0 + count = 0 + report = System.Text.StringBuilder() + for book in books: + if worker.CancellationPending: + return failed,report.ToString() + + volume = book.GetCustomValue("comicvine_volume") + seriesVolume = volume + bookinfo = "#" + str(book.Number) + " " + book.Series + " (" + str(book.Volume) + ")" + + count += 1 + worker.ReportProgress(count,bookinfo) + + + now = datetime.now() + print str(now.strftime("%m/%d/%Y, %H:%M:%S")) + + if volume not in CheckedVolumes: + IssueCount = None + + # Try local database first (unless always_use_api is set) + if use_local_db: + IssueCount = getIssueCountFromDB(volume, db_path) + if IssueCount is not None: + print str(volume) + "'s count is " + str(IssueCount) + " (from local DB)" + report.Append(str(IssueCount)) + + # Fall back to API if not found in local DB (unless skip_api_if_not_in_db is set) + if IssueCount is None and use_api: + QUERY = "https://comicvine.gamespot.com/api/volume/4050-"+ volume +"/?api_key=" + API_KEY + "&format=xml&field_list=count_of_issues" + print QUERY + data = _read_url(QUERY.encode('utf-8')) + + if settings.api_delay > 0: + time.sleep(settings.api_delay) + + doc = System.Xml.XmlDocument() + doc.LoadXml(data) + elemList = doc.GetElementsByTagName("count_of_issues") + + for i in elemList: + IssueCount = int(i.InnerXml) + if i.InnerText == "": + failed += 1 + report.Append(i.InnerText) + else: + report.Append(i.InnerText) + print str(volume) + "'s count is " + str(IssueCount) + " (from API)" + elif IssueCount is None: + print str(volume) + " not found in local DB (API skipped per settings)" + failed += 1 + + CheckedVolumes.append(volume) + IssueCountList.append(IssueCount if IssueCount else 0) + + # Update all books with this volume using the pre-built index (fast O(1) lookup) + if seriesVolume in volume_to_books: + for b in volume_to_books[seriesVolume]: + if b.Number.isnumeric: + b.Count = IssueCount if IssueCount else 0 + b.SetCustomValue("comicvine_issue_count",str(IssueCount if IssueCount else 0)) + + return failed,report.ToString() + +def _read_url(url): + + page = '' + + requestUri = url + + ServicePointManager.SecurityProtocol = SecurityProtocolType.Tls12 | SecurityProtocolType.Ssl3; + + Req = HttpWebRequest.Create(requestUri) + Req.Timeout = 60000 + Req.UserAgent = "CVIssueCount/" + " (https://gitea.baerentsen.space/FrederikBaerentsen/ComicRack_Scripts/src/branch/master/CVIssueCount)" + Req.AutomaticDecompression = DecompressionMethods.Deflate | DecompressionMethods.GZip + + #Req.Referer = requestUri + Req.Accept = 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8' + Req.Headers.Add('Accept-Language','en-US,en;q=0.9,it;q=0.8,fr;q=0.7,de-DE;q=0.6,de;q=0.5') + + Req.KeepAlive = True + webresponse = Req.GetResponse() + + a = webresponse.Cookies + + inStream = webresponse.GetResponseStream() + encode = Encoding.GetEncoding("utf-8") + ReadStream = StreamReader(inStream, encode) + page = ReadStream.ReadToEnd() + + + try: + inStream.Close() + webresponse.Close() + except: + pass + + return page diff --git a/Package.ini b/Package.ini new file mode 100644 index 0000000..75e6e6b --- /dev/null +++ b/Package.ini @@ -0,0 +1,4 @@ +Name=CV Issue Count +Author=Frederik Baerentsen +Version=3.1 +Description=Complete series count with the Comic Vine issue count \ No newline at end of file diff --git a/UserDict.py b/UserDict.py new file mode 100644 index 0000000..732b327 --- /dev/null +++ b/UserDict.py @@ -0,0 +1,213 @@ +"""A more or less complete user-defined wrapper around dictionary objects.""" + +class UserDict: + def __init__(*args, **kwargs): + if not args: + raise TypeError("descriptor '__init__' of 'UserDict' object " + "needs an argument") + self = args[0] + args = args[1:] + if len(args) > 1: + raise TypeError('expected at most 1 arguments, got %d' % len(args)) + if args: + dict = args[0] + elif 'dict' in kwargs: + dict = kwargs.pop('dict') + import warnings + warnings.warn("Passing 'dict' as keyword argument is " + "deprecated", PendingDeprecationWarning, + stacklevel=2) + else: + dict = None + self.data = {} + if dict is not None: + self.update(dict) + if len(kwargs): + self.update(kwargs) + def __repr__(self): return repr(self.data) + def __cmp__(self, dict): + if isinstance(dict, UserDict): + return cmp(self.data, dict.data) + else: + return cmp(self.data, dict) + __hash__ = None # Avoid Py3k warning + def __len__(self): return len(self.data) + def __getitem__(self, key): + if key in self.data: + return self.data[key] + if hasattr(self.__class__, "__missing__"): + return self.__class__.__missing__(self, key) + raise KeyError(key) + def __setitem__(self, key, item): self.data[key] = item + def __delitem__(self, key): del self.data[key] + def clear(self): self.data.clear() + def copy(self): + if self.__class__ is UserDict: + return UserDict(self.data.copy()) + import copy + data = self.data + try: + self.data = {} + c = copy.copy(self) + finally: + self.data = data + c.update(self) + return c + def keys(self): return self.data.keys() + def items(self): return self.data.items() + def iteritems(self): return self.data.iteritems() + def iterkeys(self): return self.data.iterkeys() + def itervalues(self): return self.data.itervalues() + def values(self): return self.data.values() + def has_key(self, key): return key in self.data + def update(*args, **kwargs): + if not args: + raise TypeError("descriptor 'update' of 'UserDict' object " + "needs an argument") + self = args[0] + args = args[1:] + if len(args) > 1: + raise TypeError('expected at most 1 arguments, got %d' % len(args)) + if args: + dict = args[0] + elif 'dict' in kwargs: + dict = kwargs.pop('dict') + import warnings + warnings.warn("Passing 'dict' as keyword argument is deprecated", + PendingDeprecationWarning, stacklevel=2) + else: + dict = None + if dict is None: + pass + elif isinstance(dict, UserDict): + self.data.update(dict.data) + elif isinstance(dict, type({})) or not hasattr(dict, 'items'): + self.data.update(dict) + else: + for k, v in dict.items(): + self[k] = v + if len(kwargs): + self.data.update(kwargs) + def get(self, key, failobj=None): + if key not in self: + return failobj + return self[key] + def setdefault(self, key, failobj=None): + if key not in self: + self[key] = failobj + return self[key] + def pop(self, key, *args): + return self.data.pop(key, *args) + def popitem(self): + return self.data.popitem() + def __contains__(self, key): + return key in self.data + @classmethod + def fromkeys(cls, iterable, value=None): + d = cls() + for key in iterable: + d[key] = value + return d + +class IterableUserDict(UserDict): + def __iter__(self): + return iter(self.data) + +import _abcoll +_abcoll.MutableMapping.register(IterableUserDict) + + +class DictMixin: + # Mixin defining all dictionary methods for classes that already have + # a minimum dictionary interface including getitem, setitem, delitem, + # and keys. Without knowledge of the subclass constructor, the mixin + # does not define __init__() or copy(). In addition to the four base + # methods, progressively more efficiency comes with defining + # __contains__(), __iter__(), and iteritems(). + + # second level definitions support higher levels + def __iter__(self): + for k in self.keys(): + yield k + def has_key(self, key): + try: + self[key] + except KeyError: + return False + return True + def __contains__(self, key): + return self.has_key(key) + + # third level takes advantage of second level definitions + def iteritems(self): + for k in self: + yield (k, self[k]) + def iterkeys(self): + return self.__iter__() + + # fourth level uses definitions from lower levels + def itervalues(self): + for _, v in self.iteritems(): + yield v + def values(self): + return [v for _, v in self.iteritems()] + def items(self): + return list(self.iteritems()) + def clear(self): + for key in self.keys(): + del self[key] + def setdefault(self, key, default=None): + try: + return self[key] + except KeyError: + self[key] = default + return default + def pop(self, key, *args): + if len(args) > 1: + raise TypeError, "pop expected at most 2 arguments, got "\ + + repr(1 + len(args)) + try: + value = self[key] + except KeyError: + if args: + return args[0] + raise + del self[key] + return value + def popitem(self): + try: + k, v = self.iteritems().next() + except StopIteration: + raise KeyError, 'container is empty' + del self[k] + return (k, v) + def update(self, other=None, **kwargs): + # Make progressively weaker assumptions about "other" + if other is None: + pass + elif hasattr(other, 'iteritems'): # iteritems saves memory and lookups + for k, v in other.iteritems(): + self[k] = v + elif hasattr(other, 'keys'): + for k in other.keys(): + self[k] = other[k] + else: + for k, v in other: + self[k] = v + if kwargs: + self.update(kwargs) + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + def __repr__(self): + return repr(dict(self.iteritems())) + def __cmp__(self, other): + if other is None: + return 1 + if isinstance(other, DictMixin): + other = dict(other.iteritems()) + return cmp(dict(self.iteritems()), other) + def __len__(self): + return len(self.keys()) diff --git a/__future__.py b/__future__.py new file mode 100644 index 0000000..e0996eb --- /dev/null +++ b/__future__.py @@ -0,0 +1,128 @@ +"""Record of phased-in incompatible language changes. + +Each line is of the form: + + FeatureName = "_Feature(" OptionalRelease "," MandatoryRelease "," + CompilerFlag ")" + +where, normally, OptionalRelease < MandatoryRelease, and both are 5-tuples +of the same form as sys.version_info: + + (PY_MAJOR_VERSION, # the 2 in 2.1.0a3; an int + PY_MINOR_VERSION, # the 1; an int + PY_MICRO_VERSION, # the 0; an int + PY_RELEASE_LEVEL, # "alpha", "beta", "candidate" or "final"; string + PY_RELEASE_SERIAL # the 3; an int + ) + +OptionalRelease records the first release in which + + from __future__ import FeatureName + +was accepted. + +In the case of MandatoryReleases that have not yet occurred, +MandatoryRelease predicts the release in which the feature will become part +of the language. + +Else MandatoryRelease records when the feature became part of the language; +in releases at or after that, modules no longer need + + from __future__ import FeatureName + +to use the feature in question, but may continue to use such imports. + +MandatoryRelease may also be None, meaning that a planned feature got +dropped. + +Instances of class _Feature have two corresponding methods, +.getOptionalRelease() and .getMandatoryRelease(). + +CompilerFlag is the (bitfield) flag that should be passed in the fourth +argument to the builtin function compile() to enable the feature in +dynamically compiled code. This flag is stored in the .compiler_flag +attribute on _Future instances. These values must match the appropriate +#defines of CO_xxx flags in Include/compile.h. + +No feature line is ever to be deleted from this file. +""" + +all_feature_names = [ + "nested_scopes", + "generators", + "division", + "absolute_import", + "with_statement", + "print_function", + "unicode_literals", +] + +__all__ = ["all_feature_names"] + all_feature_names + +# The CO_xxx symbols are defined here under the same names used by +# compile.h, so that an editor search will find them here. However, +# they're not exported in __all__, because they don't really belong to +# this module. +CO_NESTED = 0x0010 # nested_scopes +CO_GENERATOR_ALLOWED = 0 # generators (obsolete, was 0x1000) +CO_FUTURE_DIVISION = 0x2000 # division +CO_FUTURE_ABSOLUTE_IMPORT = 0x4000 # perform absolute imports by default +CO_FUTURE_WITH_STATEMENT = 0x8000 # with statement +CO_FUTURE_PRINT_FUNCTION = 0x10000 # print function +CO_FUTURE_UNICODE_LITERALS = 0x20000 # unicode string literals + +class _Feature: + def __init__(self, optionalRelease, mandatoryRelease, compiler_flag): + self.optional = optionalRelease + self.mandatory = mandatoryRelease + self.compiler_flag = compiler_flag + + def getOptionalRelease(self): + """Return first release in which this feature was recognized. + + This is a 5-tuple, of the same form as sys.version_info. + """ + + return self.optional + + def getMandatoryRelease(self): + """Return release in which this feature will become mandatory. + + This is a 5-tuple, of the same form as sys.version_info, or, if + the feature was dropped, is None. + """ + + return self.mandatory + + def __repr__(self): + return "_Feature" + repr((self.optional, + self.mandatory, + self.compiler_flag)) + +nested_scopes = _Feature((2, 1, 0, "beta", 1), + (2, 2, 0, "alpha", 0), + CO_NESTED) + +generators = _Feature((2, 2, 0, "alpha", 1), + (2, 3, 0, "final", 0), + CO_GENERATOR_ALLOWED) + +division = _Feature((2, 2, 0, "alpha", 2), + (3, 0, 0, "alpha", 0), + CO_FUTURE_DIVISION) + +absolute_import = _Feature((2, 5, 0, "alpha", 1), + (3, 0, 0, "alpha", 0), + CO_FUTURE_ABSOLUTE_IMPORT) + +with_statement = _Feature((2, 5, 0, "alpha", 1), + (2, 6, 0, "alpha", 0), + CO_FUTURE_WITH_STATEMENT) + +print_function = _Feature((2, 6, 0, "alpha", 2), + (3, 0, 0, "alpha", 0), + CO_FUTURE_PRINT_FUNCTION) + +unicode_literals = _Feature((2, 6, 0, "alpha", 2), + (3, 0, 0, "alpha", 0), + CO_FUTURE_UNICODE_LITERALS) diff --git a/_abcoll.py b/_abcoll.py new file mode 100644 index 0000000..b643692 --- /dev/null +++ b/_abcoll.py @@ -0,0 +1,695 @@ +# Copyright 2007 Google, Inc. All Rights Reserved. +# Licensed to PSF under a Contributor Agreement. + +"""Abstract Base Classes (ABCs) for collections, according to PEP 3119. + +DON'T USE THIS MODULE DIRECTLY! The classes here should be imported +via collections; they are defined here only to alleviate certain +bootstrapping issues. Unit tests are in test_collections. +""" + +from abc import ABCMeta, abstractmethod +import sys + +__all__ = ["Hashable", "Iterable", "Iterator", + "Sized", "Container", "Callable", + "Set", "MutableSet", + "Mapping", "MutableMapping", + "MappingView", "KeysView", "ItemsView", "ValuesView", + "Sequence", "MutableSequence", + ] + +### ONE-TRICK PONIES ### + +def _hasattr(C, attr): + try: + return any(attr in B.__dict__ for B in C.__mro__) + except AttributeError: + # Old-style class + return hasattr(C, attr) + + +class Hashable: + __metaclass__ = ABCMeta + + @abstractmethod + def __hash__(self): + return 0 + + @classmethod + def __subclasshook__(cls, C): + if cls is Hashable: + try: + for B in C.__mro__: + if "__hash__" in B.__dict__: + if B.__dict__["__hash__"]: + return True + break + except AttributeError: + # Old-style class + if getattr(C, "__hash__", None): + return True + return NotImplemented + + +class Iterable: + __metaclass__ = ABCMeta + + @abstractmethod + def __iter__(self): + while False: + yield None + + @classmethod + def __subclasshook__(cls, C): + if cls is Iterable: + if _hasattr(C, "__iter__"): + return True + return NotImplemented + +Iterable.register(str) + + +class Iterator(Iterable): + + @abstractmethod + def next(self): + 'Return the next item from the iterator. When exhausted, raise StopIteration' + raise StopIteration + + def __iter__(self): + return self + + @classmethod + def __subclasshook__(cls, C): + if cls is Iterator: + if _hasattr(C, "next") and _hasattr(C, "__iter__"): + return True + return NotImplemented + + +class Sized: + __metaclass__ = ABCMeta + + @abstractmethod + def __len__(self): + return 0 + + @classmethod + def __subclasshook__(cls, C): + if cls is Sized: + if _hasattr(C, "__len__"): + return True + return NotImplemented + + +class Container: + __metaclass__ = ABCMeta + + @abstractmethod + def __contains__(self, x): + return False + + @classmethod + def __subclasshook__(cls, C): + if cls is Container: + if _hasattr(C, "__contains__"): + return True + return NotImplemented + + +class Callable: + __metaclass__ = ABCMeta + + @abstractmethod + def __call__(self, *args, **kwds): + return False + + @classmethod + def __subclasshook__(cls, C): + if cls is Callable: + if _hasattr(C, "__call__"): + return True + return NotImplemented + + +### SETS ### + + +class Set(Sized, Iterable, Container): + """A set is a finite, iterable container. + + This class provides concrete generic implementations of all + methods except for __contains__, __iter__ and __len__. + + To override the comparisons (presumably for speed, as the + semantics are fixed), redefine __le__ and __ge__, + then the other operations will automatically follow suit. + """ + + def __le__(self, other): + if not isinstance(other, Set): + return NotImplemented + if len(self) > len(other): + return False + for elem in self: + if elem not in other: + return False + return True + + def __lt__(self, other): + if not isinstance(other, Set): + return NotImplemented + return len(self) < len(other) and self.__le__(other) + + def __gt__(self, other): + if not isinstance(other, Set): + return NotImplemented + return len(self) > len(other) and self.__ge__(other) + + def __ge__(self, other): + if not isinstance(other, Set): + return NotImplemented + if len(self) < len(other): + return False + for elem in other: + if elem not in self: + return False + return True + + def __eq__(self, other): + if not isinstance(other, Set): + return NotImplemented + return len(self) == len(other) and self.__le__(other) + + def __ne__(self, other): + return not (self == other) + + @classmethod + def _from_iterable(cls, it): + '''Construct an instance of the class from any iterable input. + + Must override this method if the class constructor signature + does not accept an iterable for an input. + ''' + return cls(it) + + def __and__(self, other): + if not isinstance(other, Iterable): + return NotImplemented + return self._from_iterable(value for value in other if value in self) + + __rand__ = __and__ + + def isdisjoint(self, other): + 'Return True if two sets have a null intersection.' + for value in other: + if value in self: + return False + return True + + def __or__(self, other): + if not isinstance(other, Iterable): + return NotImplemented + chain = (e for s in (self, other) for e in s) + return self._from_iterable(chain) + + __ror__ = __or__ + + def __sub__(self, other): + if not isinstance(other, Set): + if not isinstance(other, Iterable): + return NotImplemented + other = self._from_iterable(other) + return self._from_iterable(value for value in self + if value not in other) + + def __rsub__(self, other): + if not isinstance(other, Set): + if not isinstance(other, Iterable): + return NotImplemented + other = self._from_iterable(other) + return self._from_iterable(value for value in other + if value not in self) + + def __xor__(self, other): + if not isinstance(other, Set): + if not isinstance(other, Iterable): + return NotImplemented + other = self._from_iterable(other) + return (self - other) | (other - self) + + __rxor__ = __xor__ + + # Sets are not hashable by default, but subclasses can change this + __hash__ = None + + def _hash(self): + """Compute the hash value of a set. + + Note that we don't define __hash__: not all sets are hashable. + But if you define a hashable set type, its __hash__ should + call this function. + + This must be compatible __eq__. + + All sets ought to compare equal if they contain the same + elements, regardless of how they are implemented, and + regardless of the order of the elements; so there's not much + freedom for __eq__ or __hash__. We match the algorithm used + by the built-in frozenset type. + """ + MAX = sys.maxint + MASK = 2 * MAX + 1 + n = len(self) + h = 1927868237 * (n + 1) + h &= MASK + for x in self: + hx = hash(x) + h ^= (hx ^ (hx << 16) ^ 89869747) * 3644798167 + h &= MASK + h = h * 69069 + 907133923 + h &= MASK + if h > MAX: + h -= MASK + 1 + if h == -1: + h = 590923713 + return h + +Set.register(frozenset) + + +class MutableSet(Set): + """A mutable set is a finite, iterable container. + + This class provides concrete generic implementations of all + methods except for __contains__, __iter__, __len__, + add(), and discard(). + + To override the comparisons (presumably for speed, as the + semantics are fixed), all you have to do is redefine __le__ and + then the other operations will automatically follow suit. + """ + + @abstractmethod + def add(self, value): + """Add an element.""" + raise NotImplementedError + + @abstractmethod + def discard(self, value): + """Remove an element. Do not raise an exception if absent.""" + raise NotImplementedError + + def remove(self, value): + """Remove an element. If not a member, raise a KeyError.""" + if value not in self: + raise KeyError(value) + self.discard(value) + + def pop(self): + """Return the popped value. Raise KeyError if empty.""" + it = iter(self) + try: + value = next(it) + except StopIteration: + raise KeyError + self.discard(value) + return value + + def clear(self): + """This is slow (creates N new iterators!) but effective.""" + try: + while True: + self.pop() + except KeyError: + pass + + def __ior__(self, it): + for value in it: + self.add(value) + return self + + def __iand__(self, it): + for value in (self - it): + self.discard(value) + return self + + def __ixor__(self, it): + if it is self: + self.clear() + else: + if not isinstance(it, Set): + it = self._from_iterable(it) + for value in it: + if value in self: + self.discard(value) + else: + self.add(value) + return self + + def __isub__(self, it): + if it is self: + self.clear() + else: + for value in it: + self.discard(value) + return self + +MutableSet.register(set) + + +### MAPPINGS ### + + +class Mapping(Sized, Iterable, Container): + + """A Mapping is a generic container for associating key/value + pairs. + + This class provides concrete generic implementations of all + methods except for __getitem__, __iter__, and __len__. + + """ + + @abstractmethod + def __getitem__(self, key): + raise KeyError + + def get(self, key, default=None): + 'D.get(k[,d]) -> D[k] if k in D, else d. d defaults to None.' + try: + return self[key] + except KeyError: + return default + + def __contains__(self, key): + try: + self[key] + except KeyError: + return False + else: + return True + + def iterkeys(self): + 'D.iterkeys() -> an iterator over the keys of D' + return iter(self) + + def itervalues(self): + 'D.itervalues() -> an iterator over the values of D' + for key in self: + yield self[key] + + def iteritems(self): + 'D.iteritems() -> an iterator over the (key, value) items of D' + for key in self: + yield (key, self[key]) + + def keys(self): + "D.keys() -> list of D's keys" + return list(self) + + def items(self): + "D.items() -> list of D's (key, value) pairs, as 2-tuples" + return [(key, self[key]) for key in self] + + def values(self): + "D.values() -> list of D's values" + return [self[key] for key in self] + + # Mappings are not hashable by default, but subclasses can change this + __hash__ = None + + def __eq__(self, other): + if not isinstance(other, Mapping): + return NotImplemented + return dict(self.items()) == dict(other.items()) + + def __ne__(self, other): + return not (self == other) + +class MappingView(Sized): + + def __init__(self, mapping): + self._mapping = mapping + + def __len__(self): + return len(self._mapping) + + def __repr__(self): + return '{0.__class__.__name__}({0._mapping!r})'.format(self) + + +class KeysView(MappingView, Set): + + @classmethod + def _from_iterable(self, it): + return set(it) + + def __contains__(self, key): + return key in self._mapping + + def __iter__(self): + for key in self._mapping: + yield key + +KeysView.register(type({}.viewkeys())) + +class ItemsView(MappingView, Set): + + @classmethod + def _from_iterable(self, it): + return set(it) + + def __contains__(self, item): + key, value = item + try: + v = self._mapping[key] + except KeyError: + return False + else: + return v == value + + def __iter__(self): + for key in self._mapping: + yield (key, self._mapping[key]) + +ItemsView.register(type({}.viewitems())) + +class ValuesView(MappingView): + + def __contains__(self, value): + for key in self._mapping: + if value == self._mapping[key]: + return True + return False + + def __iter__(self): + for key in self._mapping: + yield self._mapping[key] + +ValuesView.register(type({}.viewvalues())) + +class MutableMapping(Mapping): + + """A MutableMapping is a generic container for associating + key/value pairs. + + This class provides concrete generic implementations of all + methods except for __getitem__, __setitem__, __delitem__, + __iter__, and __len__. + + """ + + @abstractmethod + def __setitem__(self, key, value): + raise KeyError + + @abstractmethod + def __delitem__(self, key): + raise KeyError + + __marker = object() + + def pop(self, key, default=__marker): + '''D.pop(k[,d]) -> v, remove specified key and return the corresponding value. + If key is not found, d is returned if given, otherwise KeyError is raised. + ''' + try: + value = self[key] + except KeyError: + if default is self.__marker: + raise + return default + else: + del self[key] + return value + + def popitem(self): + '''D.popitem() -> (k, v), remove and return some (key, value) pair + as a 2-tuple; but raise KeyError if D is empty. + ''' + try: + key = next(iter(self)) + except StopIteration: + raise KeyError + value = self[key] + del self[key] + return key, value + + def clear(self): + 'D.clear() -> None. Remove all items from D.' + try: + while True: + self.popitem() + except KeyError: + pass + + def update(*args, **kwds): + ''' D.update([E, ]**F) -> None. Update D from mapping/iterable E and F. + If E present and has a .keys() method, does: for k in E: D[k] = E[k] + If E present and lacks .keys() method, does: for (k, v) in E: D[k] = v + In either case, this is followed by: for k, v in F.items(): D[k] = v + ''' + if not args: + raise TypeError("descriptor 'update' of 'MutableMapping' object " + "needs an argument") + self = args[0] + args = args[1:] + if len(args) > 1: + raise TypeError('update expected at most 1 arguments, got %d' % + len(args)) + if args: + other = args[0] + if isinstance(other, Mapping): + for key in other: + self[key] = other[key] + elif hasattr(other, "keys"): + for key in other.keys(): + self[key] = other[key] + else: + for key, value in other: + self[key] = value + for key, value in kwds.items(): + self[key] = value + + def setdefault(self, key, default=None): + 'D.setdefault(k[,d]) -> D.get(k,d), also set D[k]=d if k not in D' + try: + return self[key] + except KeyError: + self[key] = default + return default + +MutableMapping.register(dict) + + +### SEQUENCES ### + + +class Sequence(Sized, Iterable, Container): + """All the operations on a read-only sequence. + + Concrete subclasses must override __new__ or __init__, + __getitem__, and __len__. + """ + + @abstractmethod + def __getitem__(self, index): + raise IndexError + + def __iter__(self): + i = 0 + try: + while True: + v = self[i] + yield v + i += 1 + except IndexError: + return + + def __contains__(self, value): + for v in self: + if v == value: + return True + return False + + def __reversed__(self): + for i in reversed(range(len(self))): + yield self[i] + + def index(self, value): + '''S.index(value) -> integer -- return first index of value. + Raises ValueError if the value is not present. + ''' + for i, v in enumerate(self): + if v == value: + return i + raise ValueError + + def count(self, value): + 'S.count(value) -> integer -- return number of occurrences of value' + return sum(1 for v in self if v == value) + +Sequence.register(tuple) +Sequence.register(basestring) +Sequence.register(buffer) +Sequence.register(xrange) + + +class MutableSequence(Sequence): + + """All the operations on a read-only sequence. + + Concrete subclasses must provide __new__ or __init__, + __getitem__, __setitem__, __delitem__, __len__, and insert(). + + """ + + @abstractmethod + def __setitem__(self, index, value): + raise IndexError + + @abstractmethod + def __delitem__(self, index): + raise IndexError + + @abstractmethod + def insert(self, index, value): + 'S.insert(index, object) -- insert object before index' + raise IndexError + + def append(self, value): + 'S.append(object) -- append object to the end of the sequence' + self.insert(len(self), value) + + def reverse(self): + 'S.reverse() -- reverse *IN PLACE*' + n = len(self) + for i in range(n//2): + self[i], self[n-i-1] = self[n-i-1], self[i] + + def extend(self, values): + 'S.extend(iterable) -- extend sequence by appending elements from the iterable' + for v in values: + self.append(v) + + def pop(self, index=-1): + '''S.pop([index]) -> item -- remove and return item at index (default last). + Raise IndexError if list is empty or index is out of range. + ''' + v = self[index] + del self[index] + return v + + def remove(self, value): + '''S.remove(value) -- remove first occurrence of value. + Raise ValueError if the value is not present. + ''' + del self[self.index(value)] + + def __iadd__(self, values): + self.extend(values) + return self + +MutableSequence.register(list) diff --git a/_weakrefset.py b/_weakrefset.py new file mode 100644 index 0000000..627959b --- /dev/null +++ b/_weakrefset.py @@ -0,0 +1,204 @@ +# Access WeakSet through the weakref module. +# This code is separated-out because it is needed +# by abc.py to load everything else at startup. + +from _weakref import ref + +__all__ = ['WeakSet'] + + +class _IterationGuard(object): + # This context manager registers itself in the current iterators of the + # weak container, such as to delay all removals until the context manager + # exits. + # This technique should be relatively thread-safe (since sets are). + + def __init__(self, weakcontainer): + # Don't create cycles + self.weakcontainer = ref(weakcontainer) + + def __enter__(self): + w = self.weakcontainer() + if w is not None: + w._iterating.add(self) + return self + + def __exit__(self, e, t, b): + w = self.weakcontainer() + if w is not None: + s = w._iterating + s.remove(self) + if not s: + w._commit_removals() + + +class WeakSet(object): + def __init__(self, data=None): + self.data = set() + def _remove(item, selfref=ref(self)): + self = selfref() + if self is not None: + if self._iterating: + self._pending_removals.append(item) + else: + self.data.discard(item) + self._remove = _remove + # A list of keys to be removed + self._pending_removals = [] + self._iterating = set() + if data is not None: + self.update(data) + + def _commit_removals(self): + l = self._pending_removals + discard = self.data.discard + while l: + discard(l.pop()) + + def __iter__(self): + with _IterationGuard(self): + for itemref in self.data: + item = itemref() + if item is not None: + # Caveat: the iterator will keep a strong reference to + # `item` until it is resumed or closed. + yield item + + def __len__(self): + return len(self.data) - len(self._pending_removals) + + def __contains__(self, item): + try: + wr = ref(item) + except TypeError: + return False + return wr in self.data + + def __reduce__(self): + return (self.__class__, (list(self),), + getattr(self, '__dict__', None)) + + __hash__ = None + + def add(self, item): + if self._pending_removals: + self._commit_removals() + self.data.add(ref(item, self._remove)) + + def clear(self): + if self._pending_removals: + self._commit_removals() + self.data.clear() + + def copy(self): + return self.__class__(self) + + def pop(self): + if self._pending_removals: + self._commit_removals() + while True: + try: + itemref = self.data.pop() + except KeyError: + raise KeyError('pop from empty WeakSet') + item = itemref() + if item is not None: + return item + + def remove(self, item): + if self._pending_removals: + self._commit_removals() + self.data.remove(ref(item)) + + def discard(self, item): + if self._pending_removals: + self._commit_removals() + self.data.discard(ref(item)) + + def update(self, other): + if self._pending_removals: + self._commit_removals() + for element in other: + self.add(element) + + def __ior__(self, other): + self.update(other) + return self + + def difference(self, other): + newset = self.copy() + newset.difference_update(other) + return newset + __sub__ = difference + + def difference_update(self, other): + self.__isub__(other) + def __isub__(self, other): + if self._pending_removals: + self._commit_removals() + if self is other: + self.data.clear() + else: + self.data.difference_update(ref(item) for item in other) + return self + + def intersection(self, other): + return self.__class__(item for item in other if item in self) + __and__ = intersection + + def intersection_update(self, other): + self.__iand__(other) + def __iand__(self, other): + if self._pending_removals: + self._commit_removals() + self.data.intersection_update(ref(item) for item in other) + return self + + def issubset(self, other): + return self.data.issubset(ref(item) for item in other) + __le__ = issubset + + def __lt__(self, other): + return self.data < set(ref(item) for item in other) + + def issuperset(self, other): + return self.data.issuperset(ref(item) for item in other) + __ge__ = issuperset + + def __gt__(self, other): + return self.data > set(ref(item) for item in other) + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return self.data == set(ref(item) for item in other) + + def __ne__(self, other): + opposite = self.__eq__(other) + if opposite is NotImplemented: + return NotImplemented + return not opposite + + def symmetric_difference(self, other): + newset = self.copy() + newset.symmetric_difference_update(other) + return newset + __xor__ = symmetric_difference + + def symmetric_difference_update(self, other): + self.__ixor__(other) + def __ixor__(self, other): + if self._pending_removals: + self._commit_removals() + if self is other: + self.data.clear() + else: + self.data.symmetric_difference_update(ref(item, self._remove) for item in other) + return self + + def union(self, other): + return self.__class__(e for s in (self, other) for e in s) + __or__ = union + + def isdisjoint(self, other): + return len(self.intersection(other)) == 0 diff --git a/abc.py b/abc.py new file mode 100644 index 0000000..02e48a1 --- /dev/null +++ b/abc.py @@ -0,0 +1,185 @@ +# Copyright 2007 Google, Inc. All Rights Reserved. +# Licensed to PSF under a Contributor Agreement. + +"""Abstract Base Classes (ABCs) according to PEP 3119.""" + +import types + +from _weakrefset import WeakSet + +# Instance of old-style class +class _C: pass +_InstanceType = type(_C()) + + +def abstractmethod(funcobj): + """A decorator indicating abstract methods. + + Requires that the metaclass is ABCMeta or derived from it. A + class that has a metaclass derived from ABCMeta cannot be + instantiated unless all of its abstract methods are overridden. + The abstract methods can be called using any of the normal + 'super' call mechanisms. + + Usage: + + class C: + __metaclass__ = ABCMeta + @abstractmethod + def my_abstract_method(self, ...): + ... + """ + funcobj.__isabstractmethod__ = True + return funcobj + + +class abstractproperty(property): + """A decorator indicating abstract properties. + + Requires that the metaclass is ABCMeta or derived from it. A + class that has a metaclass derived from ABCMeta cannot be + instantiated unless all of its abstract properties are overridden. + The abstract properties can be called using any of the normal + 'super' call mechanisms. + + Usage: + + class C: + __metaclass__ = ABCMeta + @abstractproperty + def my_abstract_property(self): + ... + + This defines a read-only property; you can also define a read-write + abstract property using the 'long' form of property declaration: + + class C: + __metaclass__ = ABCMeta + def getx(self): ... + def setx(self, value): ... + x = abstractproperty(getx, setx) + """ + __isabstractmethod__ = True + + +class ABCMeta(type): + + """Metaclass for defining Abstract Base Classes (ABCs). + + Use this metaclass to create an ABC. An ABC can be subclassed + directly, and then acts as a mix-in class. You can also register + unrelated concrete classes (even built-in classes) and unrelated + ABCs as 'virtual subclasses' -- these and their descendants will + be considered subclasses of the registering ABC by the built-in + issubclass() function, but the registering ABC won't show up in + their MRO (Method Resolution Order) nor will method + implementations defined by the registering ABC be callable (not + even via super()). + + """ + + # A global counter that is incremented each time a class is + # registered as a virtual subclass of anything. It forces the + # negative cache to be cleared before its next use. + _abc_invalidation_counter = 0 + + def __new__(mcls, name, bases, namespace): + cls = super(ABCMeta, mcls).__new__(mcls, name, bases, namespace) + # Compute set of abstract method names + abstracts = set(name + for name, value in namespace.items() + if getattr(value, "__isabstractmethod__", False)) + for base in bases: + for name in getattr(base, "__abstractmethods__", set()): + value = getattr(cls, name, None) + if getattr(value, "__isabstractmethod__", False): + abstracts.add(name) + cls.__abstractmethods__ = frozenset(abstracts) + # Set up inheritance registry + cls._abc_registry = WeakSet() + cls._abc_cache = WeakSet() + cls._abc_negative_cache = WeakSet() + cls._abc_negative_cache_version = ABCMeta._abc_invalidation_counter + return cls + + def register(cls, subclass): + """Register a virtual subclass of an ABC.""" + if not isinstance(subclass, (type, types.ClassType)): + raise TypeError("Can only register classes") + if issubclass(subclass, cls): + return # Already a subclass + # Subtle: test for cycles *after* testing for "already a subclass"; + # this means we allow X.register(X) and interpret it as a no-op. + if issubclass(cls, subclass): + # This would create a cycle, which is bad for the algorithm below + raise RuntimeError("Refusing to create an inheritance cycle") + cls._abc_registry.add(subclass) + ABCMeta._abc_invalidation_counter += 1 # Invalidate negative cache + + def _dump_registry(cls, file=None): + """Debug helper to print the ABC registry.""" + print >> file, "Class: %s.%s" % (cls.__module__, cls.__name__) + print >> file, "Inv.counter: %s" % ABCMeta._abc_invalidation_counter + for name in sorted(cls.__dict__.keys()): + if name.startswith("_abc_"): + value = getattr(cls, name) + print >> file, "%s: %r" % (name, value) + + def __instancecheck__(cls, instance): + """Override for isinstance(instance, cls).""" + # Inline the cache checking when it's simple. + subclass = getattr(instance, '__class__', None) + if subclass is not None and subclass in cls._abc_cache: + return True + subtype = type(instance) + # Old-style instances + if subtype is _InstanceType: + subtype = subclass + if subtype is subclass or subclass is None: + if (cls._abc_negative_cache_version == + ABCMeta._abc_invalidation_counter and + subtype in cls._abc_negative_cache): + return False + # Fall back to the subclass check. + return cls.__subclasscheck__(subtype) + return (cls.__subclasscheck__(subclass) or + cls.__subclasscheck__(subtype)) + + def __subclasscheck__(cls, subclass): + """Override for issubclass(subclass, cls).""" + # Check cache + if subclass in cls._abc_cache: + return True + # Check negative cache; may have to invalidate + if cls._abc_negative_cache_version < ABCMeta._abc_invalidation_counter: + # Invalidate the negative cache + cls._abc_negative_cache = WeakSet() + cls._abc_negative_cache_version = ABCMeta._abc_invalidation_counter + elif subclass in cls._abc_negative_cache: + return False + # Check the subclass hook + ok = cls.__subclasshook__(subclass) + if ok is not NotImplemented: + assert isinstance(ok, bool) + if ok: + cls._abc_cache.add(subclass) + else: + cls._abc_negative_cache.add(subclass) + return ok + # Check if it's a direct subclass + if cls in getattr(subclass, '__mro__', ()): + cls._abc_cache.add(subclass) + return True + # Check if it's a subclass of a registered class (recursive) + for rcls in cls._abc_registry: + if issubclass(subclass, rcls): + cls._abc_cache.add(subclass) + return True + # Check if it's a subclass of a subclass (recursive) + for scls in cls.__subclasses__(): + if issubclass(subclass, scls): + cls._abc_cache.add(subclass) + return True + # No dice; update negative cache + cls._abc_negative_cache.add(subclass) + return False diff --git a/base64.py b/base64.py new file mode 100644 index 0000000..82c112c --- /dev/null +++ b/base64.py @@ -0,0 +1,364 @@ +#! /usr/bin/env python + +"""RFC 3548: Base16, Base32, Base64 Data Encodings""" + +# Modified 04-Oct-1995 by Jack Jansen to use binascii module +# Modified 30-Dec-2003 by Barry Warsaw to add full RFC 3548 support + +import re +import struct +import string +import binascii + + +__all__ = [ + # Legacy interface exports traditional RFC 1521 Base64 encodings + 'encode', 'decode', 'encodestring', 'decodestring', + # Generalized interface for other encodings + 'b64encode', 'b64decode', 'b32encode', 'b32decode', + 'b16encode', 'b16decode', + # Standard Base64 encoding + 'standard_b64encode', 'standard_b64decode', + # Some common Base64 alternatives. As referenced by RFC 3458, see thread + # starting at: + # + # http://zgp.org/pipermail/p2p-hackers/2001-September/000316.html + 'urlsafe_b64encode', 'urlsafe_b64decode', + ] + +_translation = [chr(_x) for _x in range(256)] +EMPTYSTRING = '' + + +def _translate(s, altchars): + translation = _translation[:] + for k, v in altchars.items(): + translation[ord(k)] = v + return s.translate(''.join(translation)) + + + +# Base64 encoding/decoding uses binascii + +def b64encode(s, altchars=None): + """Encode a string using Base64. + + s is the string to encode. Optional altchars must be a string of at least + length 2 (additional characters are ignored) which specifies an + alternative alphabet for the '+' and '/' characters. This allows an + application to e.g. generate url or filesystem safe Base64 strings. + + The encoded string is returned. + """ + # Strip off the trailing newline + encoded = binascii.b2a_base64(s)[:-1] + if altchars is not None: + return encoded.translate(string.maketrans(b'+/', altchars[:2])) + return encoded + + +def b64decode(s, altchars=None): + """Decode a Base64 encoded string. + + s is the string to decode. Optional altchars must be a string of at least + length 2 (additional characters are ignored) which specifies the + alternative alphabet used instead of the '+' and '/' characters. + + The decoded string is returned. A TypeError is raised if s were + incorrectly padded or if there are non-alphabet characters present in the + string. + """ + if altchars is not None: + s = s.translate(string.maketrans(altchars[:2], '+/')) + try: + return binascii.a2b_base64(s) + except binascii.Error, msg: + # Transform this exception for consistency + raise TypeError(msg) + + +def standard_b64encode(s): + """Encode a string using the standard Base64 alphabet. + + s is the string to encode. The encoded string is returned. + """ + return b64encode(s) + +def standard_b64decode(s): + """Decode a string encoded with the standard Base64 alphabet. + + s is the string to decode. The decoded string is returned. A TypeError + is raised if the string is incorrectly padded or if there are non-alphabet + characters present in the string. + """ + return b64decode(s) + +_urlsafe_encode_translation = string.maketrans(b'+/', b'-_') +_urlsafe_decode_translation = string.maketrans(b'-_', b'+/') + +def urlsafe_b64encode(s): + """Encode a string using a url-safe Base64 alphabet. + + s is the string to encode. The encoded string is returned. The alphabet + uses '-' instead of '+' and '_' instead of '/'. + """ + return b64encode(s).translate(_urlsafe_encode_translation) + +def urlsafe_b64decode(s): + """Decode a string encoded with the standard Base64 alphabet. + + s is the string to decode. The decoded string is returned. A TypeError + is raised if the string is incorrectly padded or if there are non-alphabet + characters present in the string. + + The alphabet uses '-' instead of '+' and '_' instead of '/'. + """ + return b64decode(s.translate(_urlsafe_decode_translation)) + + + +# Base32 encoding/decoding must be done in Python +_b32alphabet = { + 0: 'A', 9: 'J', 18: 'S', 27: '3', + 1: 'B', 10: 'K', 19: 'T', 28: '4', + 2: 'C', 11: 'L', 20: 'U', 29: '5', + 3: 'D', 12: 'M', 21: 'V', 30: '6', + 4: 'E', 13: 'N', 22: 'W', 31: '7', + 5: 'F', 14: 'O', 23: 'X', + 6: 'G', 15: 'P', 24: 'Y', + 7: 'H', 16: 'Q', 25: 'Z', + 8: 'I', 17: 'R', 26: '2', + } + +_b32tab = _b32alphabet.items() +_b32tab.sort() +_b32tab = [v for k, v in _b32tab] +_b32rev = dict([(v, long(k)) for k, v in _b32alphabet.items()]) + + +def b32encode(s): + """Encode a string using Base32. + + s is the string to encode. The encoded string is returned. + """ + parts = [] + quanta, leftover = divmod(len(s), 5) + # Pad the last quantum with zero bits if necessary + if leftover: + s += ('\0' * (5 - leftover)) + quanta += 1 + for i in range(quanta): + # c1 and c2 are 16 bits wide, c3 is 8 bits wide. The intent of this + # code is to process the 40 bits in units of 5 bits. So we take the 1 + # leftover bit of c1 and tack it onto c2. Then we take the 2 leftover + # bits of c2 and tack them onto c3. The shifts and masks are intended + # to give us values of exactly 5 bits in width. + c1, c2, c3 = struct.unpack('!HHB', s[i*5:(i+1)*5]) + c2 += (c1 & 1) << 16 # 17 bits wide + c3 += (c2 & 3) << 8 # 10 bits wide + parts.extend([_b32tab[c1 >> 11], # bits 1 - 5 + _b32tab[(c1 >> 6) & 0x1f], # bits 6 - 10 + _b32tab[(c1 >> 1) & 0x1f], # bits 11 - 15 + _b32tab[c2 >> 12], # bits 16 - 20 (1 - 5) + _b32tab[(c2 >> 7) & 0x1f], # bits 21 - 25 (6 - 10) + _b32tab[(c2 >> 2) & 0x1f], # bits 26 - 30 (11 - 15) + _b32tab[c3 >> 5], # bits 31 - 35 (1 - 5) + _b32tab[c3 & 0x1f], # bits 36 - 40 (1 - 5) + ]) + encoded = EMPTYSTRING.join(parts) + # Adjust for any leftover partial quanta + if leftover == 1: + return encoded[:-6] + '======' + elif leftover == 2: + return encoded[:-4] + '====' + elif leftover == 3: + return encoded[:-3] + '===' + elif leftover == 4: + return encoded[:-1] + '=' + return encoded + + +def b32decode(s, casefold=False, map01=None): + """Decode a Base32 encoded string. + + s is the string to decode. Optional casefold is a flag specifying whether + a lowercase alphabet is acceptable as input. For security purposes, the + default is False. + + RFC 3548 allows for optional mapping of the digit 0 (zero) to the letter O + (oh), and for optional mapping of the digit 1 (one) to either the letter I + (eye) or letter L (el). The optional argument map01 when not None, + specifies which letter the digit 1 should be mapped to (when map01 is not + None, the digit 0 is always mapped to the letter O). For security + purposes the default is None, so that 0 and 1 are not allowed in the + input. + + The decoded string is returned. A TypeError is raised if s were + incorrectly padded or if there are non-alphabet characters present in the + string. + """ + quanta, leftover = divmod(len(s), 8) + if leftover: + raise TypeError('Incorrect padding') + # Handle section 2.4 zero and one mapping. The flag map01 will be either + # False, or the character to map the digit 1 (one) to. It should be + # either L (el) or I (eye). + if map01: + s = s.translate(string.maketrans(b'01', b'O' + map01)) + if casefold: + s = s.upper() + # Strip off pad characters from the right. We need to count the pad + # characters because this will tell us how many null bytes to remove from + # the end of the decoded string. + padchars = 0 + mo = re.search('(?P[=]*)$', s) + if mo: + padchars = len(mo.group('pad')) + if padchars > 0: + s = s[:-padchars] + # Now decode the full quanta + parts = [] + acc = 0 + shift = 35 + for c in s: + val = _b32rev.get(c) + if val is None: + raise TypeError('Non-base32 digit found') + acc += _b32rev[c] << shift + shift -= 5 + if shift < 0: + parts.append(binascii.unhexlify('%010x' % acc)) + acc = 0 + shift = 35 + # Process the last, partial quanta + last = binascii.unhexlify('%010x' % acc) + if padchars == 0: + last = '' # No characters + elif padchars == 1: + last = last[:-1] + elif padchars == 3: + last = last[:-2] + elif padchars == 4: + last = last[:-3] + elif padchars == 6: + last = last[:-4] + else: + raise TypeError('Incorrect padding') + parts.append(last) + return EMPTYSTRING.join(parts) + + + +# RFC 3548, Base 16 Alphabet specifies uppercase, but hexlify() returns +# lowercase. The RFC also recommends against accepting input case +# insensitively. +def b16encode(s): + """Encode a string using Base16. + + s is the string to encode. The encoded string is returned. + """ + return binascii.hexlify(s).upper() + + +def b16decode(s, casefold=False): + """Decode a Base16 encoded string. + + s is the string to decode. Optional casefold is a flag specifying whether + a lowercase alphabet is acceptable as input. For security purposes, the + default is False. + + The decoded string is returned. A TypeError is raised if s were + incorrectly padded or if there are non-alphabet characters present in the + string. + """ + if casefold: + s = s.upper() + if re.search('[^0-9A-F]', s): + raise TypeError('Non-base16 digit found') + return binascii.unhexlify(s) + + + +# Legacy interface. This code could be cleaned up since I don't believe +# binascii has any line length limitations. It just doesn't seem worth it +# though. + +MAXLINESIZE = 76 # Excluding the CRLF +MAXBINSIZE = (MAXLINESIZE//4)*3 + +def encode(input, output): + """Encode a file.""" + while True: + s = input.read(MAXBINSIZE) + if not s: + break + while len(s) < MAXBINSIZE: + ns = input.read(MAXBINSIZE-len(s)) + if not ns: + break + s += ns + line = binascii.b2a_base64(s) + output.write(line) + + +def decode(input, output): + """Decode a file.""" + while True: + line = input.readline() + if not line: + break + s = binascii.a2b_base64(line) + output.write(s) + + +def encodestring(s): + """Encode a string into multiple lines of base-64 data.""" + pieces = [] + for i in range(0, len(s), MAXBINSIZE): + chunk = s[i : i + MAXBINSIZE] + pieces.append(binascii.b2a_base64(chunk)) + return "".join(pieces) + + +def decodestring(s): + """Decode a string.""" + return binascii.a2b_base64(s) + + + +# Useable as a script... +def test(): + """Small test program""" + import sys, getopt + try: + opts, args = getopt.getopt(sys.argv[1:], 'deut') + except getopt.error, msg: + sys.stdout = sys.stderr + print msg + print """usage: %s [-d|-e|-u|-t] [file|-] + -d, -u: decode + -e: encode (default) + -t: encode and decode string 'Aladdin:open sesame'"""%sys.argv[0] + sys.exit(2) + func = encode + for o, a in opts: + if o == '-e': func = encode + if o == '-d': func = decode + if o == '-u': func = decode + if o == '-t': test1(); return + if args and args[0] != '-': + with open(args[0], 'rb') as f: + func(f, sys.stdout) + else: + func(sys.stdin, sys.stdout) + + +def test1(): + s0 = "Aladdin:open sesame" + s1 = encodestring(s0) + s2 = decodestring(s1) + print s0, repr(s1), s2 + + +if __name__ == '__main__': + test() diff --git a/bisect.py b/bisect.py new file mode 100644 index 0000000..4a4d052 --- /dev/null +++ b/bisect.py @@ -0,0 +1,92 @@ +"""Bisection algorithms.""" + +def insort_right(a, x, lo=0, hi=None): + """Insert item x in list a, and keep it sorted assuming a is sorted. + + If x is already in a, insert it to the right of the rightmost x. + + Optional args lo (default 0) and hi (default len(a)) bound the + slice of a to be searched. + """ + + if lo < 0: + raise ValueError('lo must be non-negative') + if hi is None: + hi = len(a) + while lo < hi: + mid = (lo+hi)//2 + if x < a[mid]: hi = mid + else: lo = mid+1 + a.insert(lo, x) + +insort = insort_right # backward compatibility + +def bisect_right(a, x, lo=0, hi=None): + """Return the index where to insert item x in list a, assuming a is sorted. + + The return value i is such that all e in a[:i] have e <= x, and all e in + a[i:] have e > x. So if x already appears in the list, a.insert(x) will + insert just after the rightmost x already there. + + Optional args lo (default 0) and hi (default len(a)) bound the + slice of a to be searched. + """ + + if lo < 0: + raise ValueError('lo must be non-negative') + if hi is None: + hi = len(a) + while lo < hi: + mid = (lo+hi)//2 + if x < a[mid]: hi = mid + else: lo = mid+1 + return lo + +bisect = bisect_right # backward compatibility + +def insort_left(a, x, lo=0, hi=None): + """Insert item x in list a, and keep it sorted assuming a is sorted. + + If x is already in a, insert it to the left of the leftmost x. + + Optional args lo (default 0) and hi (default len(a)) bound the + slice of a to be searched. + """ + + if lo < 0: + raise ValueError('lo must be non-negative') + if hi is None: + hi = len(a) + while lo < hi: + mid = (lo+hi)//2 + if a[mid] < x: lo = mid+1 + else: hi = mid + a.insert(lo, x) + + +def bisect_left(a, x, lo=0, hi=None): + """Return the index where to insert item x in list a, assuming a is sorted. + + The return value i is such that all e in a[:i] have e < x, and all e in + a[i:] have e >= x. So if x already appears in the list, a.insert(x) will + insert just before the leftmost x already there. + + Optional args lo (default 0) and hi (default len(a)) bound the + slice of a to be searched. + """ + + if lo < 0: + raise ValueError('lo must be non-negative') + if hi is None: + hi = len(a) + while lo < hi: + mid = (lo+hi)//2 + if a[mid] < x: lo = mid+1 + else: hi = mid + return lo + +# Overwrite above definitions with a fast C implementation +try: + from _bisect import * +except ImportError: + pass diff --git a/collections.py b/collections.py new file mode 100644 index 0000000..1dcd233 --- /dev/null +++ b/collections.py @@ -0,0 +1,730 @@ +__all__ = ['Counter', 'deque', 'defaultdict', 'namedtuple', 'OrderedDict'] +# For bootstrapping reasons, the collection ABCs are defined in _abcoll.py. +# They should however be considered an integral part of collections.py. +from _abcoll import * +import _abcoll +__all__ += _abcoll.__all__ + +from _collections import deque, defaultdict +from operator import itemgetter as _itemgetter, eq as _eq +from keyword import iskeyword as _iskeyword +import sys as _sys +import heapq as _heapq +from itertools import repeat as _repeat, chain as _chain, starmap as _starmap +from itertools import imap as _imap + +try: + from thread import get_ident as _get_ident +except ImportError: + from dummy_thread import get_ident as _get_ident + + +################################################################################ +### OrderedDict +################################################################################ + +class OrderedDict(dict): + 'Dictionary that remembers insertion order' + # An inherited dict maps keys to values. + # The inherited dict provides __getitem__, __len__, __contains__, and get. + # The remaining methods are order-aware. + # Big-O running times for all methods are the same as regular dictionaries. + + # The internal self.__map dict maps keys to links in a doubly linked list. + # The circular doubly linked list starts and ends with a sentinel element. + # The sentinel element never gets deleted (this simplifies the algorithm). + # Each link is stored as a list of length three: [PREV, NEXT, KEY]. + + def __init__(*args, **kwds): + '''Initialize an ordered dictionary. The signature is the same as + regular dictionaries, but keyword arguments are not recommended because + their insertion order is arbitrary. + + ''' + if not args: + raise TypeError("descriptor '__init__' of 'OrderedDict' object " + "needs an argument") + self = args[0] + args = args[1:] + if len(args) > 1: + raise TypeError('expected at most 1 arguments, got %d' % len(args)) + try: + self.__root + except AttributeError: + self.__root = root = [] # sentinel node + root[:] = [root, root, None] + self.__map = {} + self.__update(*args, **kwds) + + def __setitem__(self, key, value, dict_setitem=dict.__setitem__): + 'od.__setitem__(i, y) <==> od[i]=y' + # Setting a new item creates a new link at the end of the linked list, + # and the inherited dictionary is updated with the new key/value pair. + if key not in self: + root = self.__root + last = root[0] + last[1] = root[0] = self.__map[key] = [last, root, key] + return dict_setitem(self, key, value) + + def __delitem__(self, key, dict_delitem=dict.__delitem__): + 'od.__delitem__(y) <==> del od[y]' + # Deleting an existing item uses self.__map to find the link which gets + # removed by updating the links in the predecessor and successor nodes. + dict_delitem(self, key) + link_prev, link_next, _ = self.__map.pop(key) + link_prev[1] = link_next # update link_prev[NEXT] + link_next[0] = link_prev # update link_next[PREV] + + def __iter__(self): + 'od.__iter__() <==> iter(od)' + # Traverse the linked list in order. + root = self.__root + curr = root[1] # start at the first node + while curr is not root: + yield curr[2] # yield the curr[KEY] + curr = curr[1] # move to next node + + def __reversed__(self): + 'od.__reversed__() <==> reversed(od)' + # Traverse the linked list in reverse order. + root = self.__root + curr = root[0] # start at the last node + while curr is not root: + yield curr[2] # yield the curr[KEY] + curr = curr[0] # move to previous node + + def clear(self): + 'od.clear() -> None. Remove all items from od.' + root = self.__root + root[:] = [root, root, None] + self.__map.clear() + dict.clear(self) + + # -- the following methods do not depend on the internal structure -- + + def keys(self): + 'od.keys() -> list of keys in od' + return list(self) + + def values(self): + 'od.values() -> list of values in od' + return [self[key] for key in self] + + def items(self): + 'od.items() -> list of (key, value) pairs in od' + return [(key, self[key]) for key in self] + + def iterkeys(self): + 'od.iterkeys() -> an iterator over the keys in od' + return iter(self) + + def itervalues(self): + 'od.itervalues -> an iterator over the values in od' + for k in self: + yield self[k] + + def iteritems(self): + 'od.iteritems -> an iterator over the (key, value) pairs in od' + for k in self: + yield (k, self[k]) + + update = MutableMapping.update + + __update = update # let subclasses override update without breaking __init__ + + __marker = object() + + def pop(self, key, default=__marker): + '''od.pop(k[,d]) -> v, remove specified key and return the corresponding + value. If key is not found, d is returned if given, otherwise KeyError + is raised. + + ''' + if key in self: + result = self[key] + del self[key] + return result + if default is self.__marker: + raise KeyError(key) + return default + + def setdefault(self, key, default=None): + 'od.setdefault(k[,d]) -> od.get(k,d), also set od[k]=d if k not in od' + if key in self: + return self[key] + self[key] = default + return default + + def popitem(self, last=True): + '''od.popitem() -> (k, v), return and remove a (key, value) pair. + Pairs are returned in LIFO order if last is true or FIFO order if false. + + ''' + if not self: + raise KeyError('dictionary is empty') + key = next(reversed(self) if last else iter(self)) + value = self.pop(key) + return key, value + + def __repr__(self, _repr_running={}): + 'od.__repr__() <==> repr(od)' + call_key = id(self), _get_ident() + if call_key in _repr_running: + return '...' + _repr_running[call_key] = 1 + try: + if not self: + return '%s()' % (self.__class__.__name__,) + return '%s(%r)' % (self.__class__.__name__, self.items()) + finally: + del _repr_running[call_key] + + def __reduce__(self): + 'Return state information for pickling' + items = [[k, self[k]] for k in self] + inst_dict = vars(self).copy() + for k in vars(OrderedDict()): + inst_dict.pop(k, None) + if inst_dict: + return (self.__class__, (items,), inst_dict) + return self.__class__, (items,) + + def copy(self): + 'od.copy() -> a shallow copy of od' + return self.__class__(self) + + @classmethod + def fromkeys(cls, iterable, value=None): + '''OD.fromkeys(S[, v]) -> New ordered dictionary with keys from S. + If not specified, the value defaults to None. + + ''' + self = cls() + for key in iterable: + self[key] = value + return self + + def __eq__(self, other): + '''od.__eq__(y) <==> od==y. Comparison to another OD is order-sensitive + while comparison to a regular mapping is order-insensitive. + + ''' + if isinstance(other, OrderedDict): + return dict.__eq__(self, other) and all(_imap(_eq, self, other)) + return dict.__eq__(self, other) + + def __ne__(self, other): + 'od.__ne__(y) <==> od!=y' + return not self == other + + # -- the following methods support python 3.x style dictionary views -- + + def viewkeys(self): + "od.viewkeys() -> a set-like object providing a view on od's keys" + return KeysView(self) + + def viewvalues(self): + "od.viewvalues() -> an object providing a view on od's values" + return ValuesView(self) + + def viewitems(self): + "od.viewitems() -> a set-like object providing a view on od's items" + return ItemsView(self) + + +################################################################################ +### namedtuple +################################################################################ + +_class_template = '''\ +class {typename}(tuple): + '{typename}({arg_list})' + + __slots__ = () + + _fields = {field_names!r} + + def __new__(_cls, {arg_list}): + 'Create new instance of {typename}({arg_list})' + return _tuple.__new__(_cls, ({arg_list})) + + @classmethod + def _make(cls, iterable, new=tuple.__new__, len=len): + 'Make a new {typename} object from a sequence or iterable' + result = new(cls, iterable) + if len(result) != {num_fields:d}: + raise TypeError('Expected {num_fields:d} arguments, got %d' % len(result)) + return result + + def __repr__(self): + 'Return a nicely formatted representation string' + return '{typename}({repr_fmt})' % self + + def _asdict(self): + 'Return a new OrderedDict which maps field names to their values' + return OrderedDict(zip(self._fields, self)) + + def _replace(_self, **kwds): + 'Return a new {typename} object replacing specified fields with new values' + result = _self._make(map(kwds.pop, {field_names!r}, _self)) + if kwds: + raise ValueError('Got unexpected field names: %r' % kwds.keys()) + return result + + def __getnewargs__(self): + 'Return self as a plain tuple. Used by copy and pickle.' + return tuple(self) + + __dict__ = _property(_asdict) + + def __getstate__(self): + 'Exclude the OrderedDict from pickling' + pass + +{field_defs} +''' + +_repr_template = '{name}=%r' + +_field_template = '''\ + {name} = _property(_itemgetter({index:d}), doc='Alias for field number {index:d}') +''' + +def namedtuple(typename, field_names, verbose=False, rename=False): + """Returns a new subclass of tuple with named fields. + + >>> Point = namedtuple('Point', ['x', 'y']) + >>> Point.__doc__ # docstring for the new class + 'Point(x, y)' + >>> p = Point(11, y=22) # instantiate with positional args or keywords + >>> p[0] + p[1] # indexable like a plain tuple + 33 + >>> x, y = p # unpack like a regular tuple + >>> x, y + (11, 22) + >>> p.x + p.y # fields also accessable by name + 33 + >>> d = p._asdict() # convert to a dictionary + >>> d['x'] + 11 + >>> Point(**d) # convert from a dictionary + Point(x=11, y=22) + >>> p._replace(x=100) # _replace() is like str.replace() but targets named fields + Point(x=100, y=22) + + """ + + # Validate the field names. At the user's option, either generate an error + # message or automatically replace the field name with a valid name. + if isinstance(field_names, basestring): + field_names = field_names.replace(',', ' ').split() + field_names = map(str, field_names) + typename = str(typename) + if rename: + seen = set() + for index, name in enumerate(field_names): + if (not all(c.isalnum() or c=='_' for c in name) + or _iskeyword(name) + or not name + or name[0].isdigit() + or name.startswith('_') + or name in seen): + field_names[index] = '_%d' % index + seen.add(name) + for name in [typename] + field_names: + if type(name) != str: + raise TypeError('Type names and field names must be strings') + if not all(c.isalnum() or c=='_' for c in name): + raise ValueError('Type names and field names can only contain ' + 'alphanumeric characters and underscores: %r' % name) + if _iskeyword(name): + raise ValueError('Type names and field names cannot be a ' + 'keyword: %r' % name) + if name[0].isdigit(): + raise ValueError('Type names and field names cannot start with ' + 'a number: %r' % name) + seen = set() + for name in field_names: + if name.startswith('_') and not rename: + raise ValueError('Field names cannot start with an underscore: ' + '%r' % name) + if name in seen: + raise ValueError('Encountered duplicate field name: %r' % name) + seen.add(name) + + # Fill-in the class template + class_definition = _class_template.format( + typename = typename, + field_names = tuple(field_names), + num_fields = len(field_names), + arg_list = repr(tuple(field_names)).replace("'", "")[1:-1], + repr_fmt = ', '.join(_repr_template.format(name=name) + for name in field_names), + field_defs = '\n'.join(_field_template.format(index=index, name=name) + for index, name in enumerate(field_names)) + ) + if verbose: + print class_definition + + # Execute the template string in a temporary namespace and support + # tracing utilities by setting a value for frame.f_globals['__name__'] + namespace = dict(_itemgetter=_itemgetter, __name__='namedtuple_%s' % typename, + OrderedDict=OrderedDict, _property=property, _tuple=tuple) + try: + exec class_definition in namespace + except SyntaxError as e: + raise SyntaxError(e.message + ':\n' + class_definition) + result = namespace[typename] + + # For pickling to work, the __module__ variable needs to be set to the frame + # where the named tuple is created. Bypass this step in environments where + # sys._getframe is not defined (Jython for example) or sys._getframe is not + # defined for arguments greater than 0 (IronPython). + try: + result.__module__ = _sys._getframe(1).f_globals.get('__name__', '__main__') + except (AttributeError, ValueError): + pass + + return result + + +######################################################################## +### Counter +######################################################################## + +class Counter(dict): + '''Dict subclass for counting hashable items. Sometimes called a bag + or multiset. Elements are stored as dictionary keys and their counts + are stored as dictionary values. + + >>> c = Counter('abcdeabcdabcaba') # count elements from a string + + >>> c.most_common(3) # three most common elements + [('a', 5), ('b', 4), ('c', 3)] + >>> sorted(c) # list all unique elements + ['a', 'b', 'c', 'd', 'e'] + >>> ''.join(sorted(c.elements())) # list elements with repetitions + 'aaaaabbbbcccdde' + >>> sum(c.values()) # total of all counts + 15 + + >>> c['a'] # count of letter 'a' + 5 + >>> for elem in 'shazam': # update counts from an iterable + ... c[elem] += 1 # by adding 1 to each element's count + >>> c['a'] # now there are seven 'a' + 7 + >>> del c['b'] # remove all 'b' + >>> c['b'] # now there are zero 'b' + 0 + + >>> d = Counter('simsalabim') # make another counter + >>> c.update(d) # add in the second counter + >>> c['a'] # now there are nine 'a' + 9 + + >>> c.clear() # empty the counter + >>> c + Counter() + + Note: If a count is set to zero or reduced to zero, it will remain + in the counter until the entry is deleted or the counter is cleared: + + >>> c = Counter('aaabbc') + >>> c['b'] -= 2 # reduce the count of 'b' by two + >>> c.most_common() # 'b' is still in, but its count is zero + [('a', 3), ('c', 1), ('b', 0)] + + ''' + # References: + # http://en.wikipedia.org/wiki/Multiset + # http://www.gnu.org/software/smalltalk/manual-base/html_node/Bag.html + # http://www.demo2s.com/Tutorial/Cpp/0380__set-multiset/Catalog0380__set-multiset.htm + # http://code.activestate.com/recipes/259174/ + # Knuth, TAOCP Vol. II section 4.6.3 + + def __init__(*args, **kwds): + '''Create a new, empty Counter object. And if given, count elements + from an input iterable. Or, initialize the count from another mapping + of elements to their counts. + + >>> c = Counter() # a new, empty counter + >>> c = Counter('gallahad') # a new counter from an iterable + >>> c = Counter({'a': 4, 'b': 2}) # a new counter from a mapping + >>> c = Counter(a=4, b=2) # a new counter from keyword args + + ''' + if not args: + raise TypeError("descriptor '__init__' of 'Counter' object " + "needs an argument") + self = args[0] + args = args[1:] + if len(args) > 1: + raise TypeError('expected at most 1 arguments, got %d' % len(args)) + super(Counter, self).__init__() + self.update(*args, **kwds) + + def __missing__(self, key): + 'The count of elements not in the Counter is zero.' + # Needed so that self[missing_item] does not raise KeyError + return 0 + + def most_common(self, n=None): + '''List the n most common elements and their counts from the most + common to the least. If n is None, then list all element counts. + + >>> Counter('abcdeabcdabcaba').most_common(3) + [('a', 5), ('b', 4), ('c', 3)] + + ''' + # Emulate Bag.sortedByCount from Smalltalk + if n is None: + return sorted(self.iteritems(), key=_itemgetter(1), reverse=True) + return _heapq.nlargest(n, self.iteritems(), key=_itemgetter(1)) + + def elements(self): + '''Iterator over elements repeating each as many times as its count. + + >>> c = Counter('ABCABC') + >>> sorted(c.elements()) + ['A', 'A', 'B', 'B', 'C', 'C'] + + # Knuth's example for prime factors of 1836: 2**2 * 3**3 * 17**1 + >>> prime_factors = Counter({2: 2, 3: 3, 17: 1}) + >>> product = 1 + >>> for factor in prime_factors.elements(): # loop over factors + ... product *= factor # and multiply them + >>> product + 1836 + + Note, if an element's count has been set to zero or is a negative + number, elements() will ignore it. + + ''' + # Emulate Bag.do from Smalltalk and Multiset.begin from C++. + return _chain.from_iterable(_starmap(_repeat, self.iteritems())) + + # Override dict methods where necessary + + @classmethod + def fromkeys(cls, iterable, v=None): + # There is no equivalent method for counters because setting v=1 + # means that no element can have a count greater than one. + raise NotImplementedError( + 'Counter.fromkeys() is undefined. Use Counter(iterable) instead.') + + def update(*args, **kwds): + '''Like dict.update() but add counts instead of replacing them. + + Source can be an iterable, a dictionary, or another Counter instance. + + >>> c = Counter('which') + >>> c.update('witch') # add elements from another iterable + >>> d = Counter('watch') + >>> c.update(d) # add elements from another counter + >>> c['h'] # four 'h' in which, witch, and watch + 4 + + ''' + # The regular dict.update() operation makes no sense here because the + # replace behavior results in the some of original untouched counts + # being mixed-in with all of the other counts for a mismash that + # doesn't have a straight-forward interpretation in most counting + # contexts. Instead, we implement straight-addition. Both the inputs + # and outputs are allowed to contain zero and negative counts. + + if not args: + raise TypeError("descriptor 'update' of 'Counter' object " + "needs an argument") + self = args[0] + args = args[1:] + if len(args) > 1: + raise TypeError('expected at most 1 arguments, got %d' % len(args)) + iterable = args[0] if args else None + if iterable is not None: + if isinstance(iterable, Mapping): + if self: + self_get = self.get + for elem, count in iterable.iteritems(): + self[elem] = self_get(elem, 0) + count + else: + super(Counter, self).update(iterable) # fast path when counter is empty + else: + self_get = self.get + for elem in iterable: + self[elem] = self_get(elem, 0) + 1 + if kwds: + self.update(kwds) + + def subtract(*args, **kwds): + '''Like dict.update() but subtracts counts instead of replacing them. + Counts can be reduced below zero. Both the inputs and outputs are + allowed to contain zero and negative counts. + + Source can be an iterable, a dictionary, or another Counter instance. + + >>> c = Counter('which') + >>> c.subtract('witch') # subtract elements from another iterable + >>> c.subtract(Counter('watch')) # subtract elements from another counter + >>> c['h'] # 2 in which, minus 1 in witch, minus 1 in watch + 0 + >>> c['w'] # 1 in which, minus 1 in witch, minus 1 in watch + -1 + + ''' + if not args: + raise TypeError("descriptor 'subtract' of 'Counter' object " + "needs an argument") + self = args[0] + args = args[1:] + if len(args) > 1: + raise TypeError('expected at most 1 arguments, got %d' % len(args)) + iterable = args[0] if args else None + if iterable is not None: + self_get = self.get + if isinstance(iterable, Mapping): + for elem, count in iterable.items(): + self[elem] = self_get(elem, 0) - count + else: + for elem in iterable: + self[elem] = self_get(elem, 0) - 1 + if kwds: + self.subtract(kwds) + + def copy(self): + 'Return a shallow copy.' + return self.__class__(self) + + def __reduce__(self): + return self.__class__, (dict(self),) + + def __delitem__(self, elem): + 'Like dict.__delitem__() but does not raise KeyError for missing values.' + if elem in self: + super(Counter, self).__delitem__(elem) + + def __repr__(self): + if not self: + return '%s()' % self.__class__.__name__ + items = ', '.join(map('%r: %r'.__mod__, self.most_common())) + return '%s({%s})' % (self.__class__.__name__, items) + + # Multiset-style mathematical operations discussed in: + # Knuth TAOCP Volume II section 4.6.3 exercise 19 + # and at http://en.wikipedia.org/wiki/Multiset + # + # Outputs guaranteed to only include positive counts. + # + # To strip negative and zero counts, add-in an empty counter: + # c += Counter() + + def __add__(self, other): + '''Add counts from two counters. + + >>> Counter('abbb') + Counter('bcc') + Counter({'b': 4, 'c': 2, 'a': 1}) + + ''' + if not isinstance(other, Counter): + return NotImplemented + result = Counter() + for elem, count in self.items(): + newcount = count + other[elem] + if newcount > 0: + result[elem] = newcount + for elem, count in other.items(): + if elem not in self and count > 0: + result[elem] = count + return result + + def __sub__(self, other): + ''' Subtract count, but keep only results with positive counts. + + >>> Counter('abbbc') - Counter('bccd') + Counter({'b': 2, 'a': 1}) + + ''' + if not isinstance(other, Counter): + return NotImplemented + result = Counter() + for elem, count in self.items(): + newcount = count - other[elem] + if newcount > 0: + result[elem] = newcount + for elem, count in other.items(): + if elem not in self and count < 0: + result[elem] = 0 - count + return result + + def __or__(self, other): + '''Union is the maximum of value in either of the input counters. + + >>> Counter('abbb') | Counter('bcc') + Counter({'b': 3, 'c': 2, 'a': 1}) + + ''' + if not isinstance(other, Counter): + return NotImplemented + result = Counter() + for elem, count in self.items(): + other_count = other[elem] + newcount = other_count if count < other_count else count + if newcount > 0: + result[elem] = newcount + for elem, count in other.items(): + if elem not in self and count > 0: + result[elem] = count + return result + + def __and__(self, other): + ''' Intersection is the minimum of corresponding counts. + + >>> Counter('abbb') & Counter('bcc') + Counter({'b': 1}) + + ''' + if not isinstance(other, Counter): + return NotImplemented + result = Counter() + for elem, count in self.items(): + other_count = other[elem] + newcount = count if count < other_count else other_count + if newcount > 0: + result[elem] = newcount + return result + + +if __name__ == '__main__': + # verify that instances can be pickled + from cPickle import loads, dumps + Point = namedtuple('Point', 'x, y', True) + p = Point(x=10, y=20) + assert p == loads(dumps(p)) + + # test and demonstrate ability to override methods + class Point(namedtuple('Point', 'x y')): + __slots__ = () + @property + def hypot(self): + return (self.x ** 2 + self.y ** 2) ** 0.5 + def __str__(self): + return 'Point: x=%6.3f y=%6.3f hypot=%6.3f' % (self.x, self.y, self.hypot) + + for p in Point(3, 4), Point(14, 5/7.): + print p + + class Point(namedtuple('Point', 'x y')): + 'Point class with optimized _make() and _replace() without error-checking' + __slots__ = () + _make = classmethod(tuple.__new__) + def _replace(self, _map=map, **kwds): + return self._make(_map(kwds.get, ('x', 'y'), self)) + + print Point(11, 22)._replace(x=100) + + Point3D = namedtuple('Point3D', Point._fields + ('z',)) + print Point3D.__doc__ + + import doctest + TestResults = namedtuple('TestResults', 'failed attempted') + print TestResults(*doctest.testmod()) diff --git a/contextlib.py b/contextlib.py new file mode 100644 index 0000000..f05205b --- /dev/null +++ b/contextlib.py @@ -0,0 +1,154 @@ +"""Utilities for with-statement contexts. See PEP 343.""" + +import sys +from functools import wraps +from warnings import warn + +__all__ = ["contextmanager", "nested", "closing"] + +class GeneratorContextManager(object): + """Helper for @contextmanager decorator.""" + + def __init__(self, gen): + self.gen = gen + + def __enter__(self): + try: + return self.gen.next() + except StopIteration: + raise RuntimeError("generator didn't yield") + + def __exit__(self, type, value, traceback): + if type is None: + try: + self.gen.next() + except StopIteration: + return + else: + raise RuntimeError("generator didn't stop") + else: + if value is None: + # Need to force instantiation so we can reliably + # tell if we get the same exception back + value = type() + try: + self.gen.throw(type, value, traceback) + raise RuntimeError("generator didn't stop after throw()") + except StopIteration, exc: + # Suppress the exception *unless* it's the same exception that + # was passed to throw(). This prevents a StopIteration + # raised inside the "with" statement from being suppressed + return exc is not value + except: + # only re-raise if it's *not* the exception that was + # passed to throw(), because __exit__() must not raise + # an exception unless __exit__() itself failed. But throw() + # has to raise the exception to signal propagation, so this + # fixes the impedance mismatch between the throw() protocol + # and the __exit__() protocol. + # + if sys.exc_info()[1] is not value: + raise + + +def contextmanager(func): + """@contextmanager decorator. + + Typical usage: + + @contextmanager + def some_generator(): + + try: + yield + finally: + + + This makes this: + + with some_generator() as : + + + equivalent to this: + + + try: + = + + finally: + + + """ + @wraps(func) + def helper(*args, **kwds): + return GeneratorContextManager(func(*args, **kwds)) + return helper + + +@contextmanager +def nested(*managers): + """Combine multiple context managers into a single nested context manager. + + This function has been deprecated in favour of the multiple manager form + of the with statement. + + The one advantage of this function over the multiple manager form of the + with statement is that argument unpacking allows it to be + used with a variable number of context managers as follows: + + with nested(*managers): + do_something() + + """ + warn("With-statements now directly support multiple context managers", + DeprecationWarning, 3) + exits = [] + vars = [] + exc = (None, None, None) + try: + for mgr in managers: + exit = mgr.__exit__ + enter = mgr.__enter__ + vars.append(enter()) + exits.append(exit) + yield vars + except: + exc = sys.exc_info() + finally: + while exits: + exit = exits.pop() + try: + if exit(*exc): + exc = (None, None, None) + except: + exc = sys.exc_info() + if exc != (None, None, None): + # Don't rely on sys.exc_info() still containing + # the right information. Another exception may + # have been raised and caught by an exit method + raise exc[0], exc[1], exc[2] + + +class closing(object): + """Context to automatically close something at the end of a block. + + Code like this: + + with closing(.open()) as f: + + + is equivalent to this: + + f = .open() + try: + + finally: + f.close() + + """ + def __init__(self, thing): + self.thing = thing + def __enter__(self): + return self.thing + def __exit__(self, *exc_info): + self.thing.close() diff --git a/functools.py b/functools.py new file mode 100644 index 0000000..53680b8 --- /dev/null +++ b/functools.py @@ -0,0 +1,100 @@ +"""functools.py - Tools for working with functions and callable objects +""" +# Python module wrapper for _functools C module +# to allow utilities written in Python to be added +# to the functools module. +# Written by Nick Coghlan +# Copyright (C) 2006 Python Software Foundation. +# See C source code for _functools credits/copyright + +from _functools import partial, reduce + +# update_wrapper() and wraps() are tools to help write +# wrapper functions that can handle naive introspection + +WRAPPER_ASSIGNMENTS = ('__module__', '__name__', '__doc__') +WRAPPER_UPDATES = ('__dict__',) +def update_wrapper(wrapper, + wrapped, + assigned = WRAPPER_ASSIGNMENTS, + updated = WRAPPER_UPDATES): + """Update a wrapper function to look like the wrapped function + + wrapper is the function to be updated + wrapped is the original function + assigned is a tuple naming the attributes assigned directly + from the wrapped function to the wrapper function (defaults to + functools.WRAPPER_ASSIGNMENTS) + updated is a tuple naming the attributes of the wrapper that + are updated with the corresponding attribute from the wrapped + function (defaults to functools.WRAPPER_UPDATES) + """ + for attr in assigned: + setattr(wrapper, attr, getattr(wrapped, attr)) + for attr in updated: + getattr(wrapper, attr).update(getattr(wrapped, attr, {})) + # Return the wrapper so this can be used as a decorator via partial() + return wrapper + +def wraps(wrapped, + assigned = WRAPPER_ASSIGNMENTS, + updated = WRAPPER_UPDATES): + """Decorator factory to apply update_wrapper() to a wrapper function + + Returns a decorator that invokes update_wrapper() with the decorated + function as the wrapper argument and the arguments to wraps() as the + remaining arguments. Default arguments are as for update_wrapper(). + This is a convenience function to simplify applying partial() to + update_wrapper(). + """ + return partial(update_wrapper, wrapped=wrapped, + assigned=assigned, updated=updated) + +def total_ordering(cls): + """Class decorator that fills in missing ordering methods""" + convert = { + '__lt__': [('__gt__', lambda self, other: not (self < other or self == other)), + ('__le__', lambda self, other: self < other or self == other), + ('__ge__', lambda self, other: not self < other)], + '__le__': [('__ge__', lambda self, other: not self <= other or self == other), + ('__lt__', lambda self, other: self <= other and not self == other), + ('__gt__', lambda self, other: not self <= other)], + '__gt__': [('__lt__', lambda self, other: not (self > other or self == other)), + ('__ge__', lambda self, other: self > other or self == other), + ('__le__', lambda self, other: not self > other)], + '__ge__': [('__le__', lambda self, other: (not self >= other) or self == other), + ('__gt__', lambda self, other: self >= other and not self == other), + ('__lt__', lambda self, other: not self >= other)] + } + roots = set(dir(cls)) & set(convert) + if not roots: + raise ValueError('must define at least one ordering operation: < > <= >=') + root = max(roots) # prefer __lt__ to __le__ to __gt__ to __ge__ + for opname, opfunc in convert[root]: + if opname not in roots: + opfunc.__name__ = opname + opfunc.__doc__ = getattr(int, opname).__doc__ + setattr(cls, opname, opfunc) + return cls + +def cmp_to_key(mycmp): + """Convert a cmp= function into a key= function""" + class K(object): + __slots__ = ['obj'] + def __init__(self, obj, *args): + self.obj = obj + def __lt__(self, other): + return mycmp(self.obj, other.obj) < 0 + def __gt__(self, other): + return mycmp(self.obj, other.obj) > 0 + def __eq__(self, other): + return mycmp(self.obj, other.obj) == 0 + def __le__(self, other): + return mycmp(self.obj, other.obj) <= 0 + def __ge__(self, other): + return mycmp(self.obj, other.obj) >= 0 + def __ne__(self, other): + return mycmp(self.obj, other.obj) != 0 + def __hash__(self): + raise TypeError('hash not implemented') + return K diff --git a/genericpath.py b/genericpath.py new file mode 100644 index 0000000..2648e54 --- /dev/null +++ b/genericpath.py @@ -0,0 +1,113 @@ +""" +Path operations common to more than one OS +Do not use directly. The OS specific modules import the appropriate +functions from this module themselves. +""" +import os +import stat + +__all__ = ['commonprefix', 'exists', 'getatime', 'getctime', 'getmtime', + 'getsize', 'isdir', 'isfile'] + + +try: + _unicode = unicode +except NameError: + # If Python is built without Unicode support, the unicode type + # will not exist. Fake one. + class _unicode(object): + pass + +# Does a path exist? +# This is false for dangling symbolic links on systems that support them. +def exists(path): + """Test whether a path exists. Returns False for broken symbolic links""" + try: + os.stat(path) + except os.error: + return False + return True + + +# This follows symbolic links, so both islink() and isdir() can be true +# for the same path on systems that support symlinks +def isfile(path): + """Test whether a path is a regular file""" + try: + st = os.stat(path) + except os.error: + return False + return stat.S_ISREG(st.st_mode) + + +# Is a path a directory? +# This follows symbolic links, so both islink() and isdir() +# can be true for the same path on systems that support symlinks +def isdir(s): + """Return true if the pathname refers to an existing directory.""" + try: + st = os.stat(s) + except os.error: + return False + return stat.S_ISDIR(st.st_mode) + + +def getsize(filename): + """Return the size of a file, reported by os.stat().""" + return os.stat(filename).st_size + + +def getmtime(filename): + """Return the last modification time of a file, reported by os.stat().""" + return os.stat(filename).st_mtime + + +def getatime(filename): + """Return the last access time of a file, reported by os.stat().""" + return os.stat(filename).st_atime + + +def getctime(filename): + """Return the metadata change time of a file, reported by os.stat().""" + return os.stat(filename).st_ctime + + +# Return the longest prefix of all list elements. +def commonprefix(m): + "Given a list of pathnames, returns the longest common leading component" + if not m: return '' + s1 = min(m) + s2 = max(m) + for i, c in enumerate(s1): + if c != s2[i]: + return s1[:i] + return s1 + +# Split a path in root and extension. +# The extension is everything starting at the last dot in the last +# pathname component; the root is everything before that. +# It is always true that root + ext == p. + +# Generic implementation of splitext, to be parametrized with +# the separators +def _splitext(p, sep, altsep, extsep): + """Split the extension from a pathname. + + Extension is everything from the last dot to the end, ignoring + leading dots. Returns "(root, ext)"; ext may be empty.""" + + sepIndex = p.rfind(sep) + if altsep: + altsepIndex = p.rfind(altsep) + sepIndex = max(sepIndex, altsepIndex) + + dotIndex = p.rfind(extsep) + if dotIndex > sepIndex: + # skip all leading dots + filenameIndex = sepIndex + 1 + while filenameIndex < dotIndex: + if p[filenameIndex] != extsep: + return p[:dotIndex], p[dotIndex:] + filenameIndex += 1 + + return p, '' diff --git a/hashlib.py b/hashlib.py new file mode 100644 index 0000000..bbd06b9 --- /dev/null +++ b/hashlib.py @@ -0,0 +1,221 @@ +# $Id$ +# +# Copyright (C) 2005 Gregory P. Smith (greg@krypto.org) +# Licensed to PSF under a Contributor Agreement. +# + +__doc__ = """hashlib module - A common interface to many hash functions. + +new(name, string='') - returns a new hash object implementing the + given hash function; initializing the hash + using the given string data. + +Named constructor functions are also available, these are much faster +than using new(): + +md5(), sha1(), sha224(), sha256(), sha384(), and sha512() + +More algorithms may be available on your platform but the above are guaranteed +to exist. See the algorithms_guaranteed and algorithms_available attributes +to find out what algorithm names can be passed to new(). + +NOTE: If you want the adler32 or crc32 hash functions they are available in +the zlib module. + +Choose your hash function wisely. Some have known collision weaknesses. +sha384 and sha512 will be slow on 32 bit platforms. + +Hash objects have these methods: + - update(arg): Update the hash object with the string arg. Repeated calls + are equivalent to a single call with the concatenation of all + the arguments. + - digest(): Return the digest of the strings passed to the update() method + so far. This may contain non-ASCII characters, including + NUL bytes. + - hexdigest(): Like digest() except the digest is returned as a string of + double length, containing only hexadecimal digits. + - copy(): Return a copy (clone) of the hash object. This can be used to + efficiently compute the digests of strings that share a common + initial substring. + +For example, to obtain the digest of the string 'Nobody inspects the +spammish repetition': + + >>> import hashlib + >>> m = hashlib.md5() + >>> m.update("Nobody inspects") + >>> m.update(" the spammish repetition") + >>> m.digest() + '\\xbbd\\x9c\\x83\\xdd\\x1e\\xa5\\xc9\\xd9\\xde\\xc9\\xa1\\x8d\\xf0\\xff\\xe9' + +More condensed: + + >>> hashlib.sha224("Nobody inspects the spammish repetition").hexdigest() + 'a4337bc45a8fc544c03f52dc550cd6e1e87021bc896588bd79e901e2' + +""" + +# This tuple and __get_builtin_constructor() must be modified if a new +# always available algorithm is added. +__always_supported = ('md5', 'sha1', 'sha224', 'sha256', 'sha384', 'sha512') + +algorithms_guaranteed = set(__always_supported) +algorithms_available = set(__always_supported) + +algorithms = __always_supported + +__all__ = __always_supported + ('new', 'algorithms_guaranteed', + 'algorithms_available', 'algorithms', + 'pbkdf2_hmac') + + +def __get_builtin_constructor(name): + try: + if name in ('SHA1', 'sha1'): + import _sha + return _sha.new + elif name in ('MD5', 'md5'): + import _md5 + return _md5.new + elif name in ('SHA256', 'sha256', 'SHA224', 'sha224'): + import _sha256 + bs = name[3:] + if bs == '256': + return _sha256.sha256 + elif bs == '224': + return _sha256.sha224 + elif name in ('SHA512', 'sha512', 'SHA384', 'sha384'): + import _sha512 + bs = name[3:] + if bs == '512': + return _sha512.sha512 + elif bs == '384': + return _sha512.sha384 + except ImportError: + pass # no extension module, this hash is unsupported. + + raise ValueError('unsupported hash type ' + name) + + +def __get_openssl_constructor(name): + try: + f = getattr(_hashlib, 'openssl_' + name) + # Allow the C module to raise ValueError. The function will be + # defined but the hash not actually available thanks to OpenSSL. + f() + # Use the C function directly (very fast) + return f + except (AttributeError, ValueError): + return __get_builtin_constructor(name) + + +def __py_new(name, string=''): + """new(name, string='') - Return a new hashing object using the named algorithm; + optionally initialized with a string. + """ + return __get_builtin_constructor(name)(string) + + +def __hash_new(name, string=''): + """new(name, string='') - Return a new hashing object using the named algorithm; + optionally initialized with a string. + """ + try: + return _hashlib.new(name, string) + except ValueError: + # If the _hashlib module (OpenSSL) doesn't support the named + # hash, try using our builtin implementations. + # This allows for SHA224/256 and SHA384/512 support even though + # the OpenSSL library prior to 0.9.8 doesn't provide them. + return __get_builtin_constructor(name)(string) + + +try: + import _hashlib + new = __hash_new + __get_hash = __get_openssl_constructor + algorithms_available = algorithms_available.union( + _hashlib.openssl_md_meth_names) +except ImportError: + new = __py_new + __get_hash = __get_builtin_constructor + +for __func_name in __always_supported: + # try them all, some may not work due to the OpenSSL + # version not supporting that algorithm. + try: + globals()[__func_name] = __get_hash(__func_name) + except ValueError: + import logging + logging.exception('code for hash %s was not found.', __func_name) + + +try: + # OpenSSL's PKCS5_PBKDF2_HMAC requires OpenSSL 1.0+ with HMAC and SHA + from _hashlib import pbkdf2_hmac +except ImportError: + import binascii + import struct + + _trans_5C = b"".join(chr(x ^ 0x5C) for x in range(256)) + _trans_36 = b"".join(chr(x ^ 0x36) for x in range(256)) + + def pbkdf2_hmac(hash_name, password, salt, iterations, dklen=None): + """Password based key derivation function 2 (PKCS #5 v2.0) + + This Python implementations based on the hmac module about as fast + as OpenSSL's PKCS5_PBKDF2_HMAC for short passwords and much faster + for long passwords. + """ + if not isinstance(hash_name, str): + raise TypeError(hash_name) + + if not isinstance(password, (bytes, bytearray)): + password = bytes(buffer(password)) + if not isinstance(salt, (bytes, bytearray)): + salt = bytes(buffer(salt)) + + # Fast inline HMAC implementation + inner = new(hash_name) + outer = new(hash_name) + blocksize = getattr(inner, 'block_size', 64) + if len(password) > blocksize: + password = new(hash_name, password).digest() + password = password + b'\x00' * (blocksize - len(password)) + inner.update(password.translate(_trans_36)) + outer.update(password.translate(_trans_5C)) + + def prf(msg, inner=inner, outer=outer): + # PBKDF2_HMAC uses the password as key. We can re-use the same + # digest objects and just update copies to skip initialization. + icpy = inner.copy() + ocpy = outer.copy() + icpy.update(msg) + ocpy.update(icpy.digest()) + return ocpy.digest() + + if iterations < 1: + raise ValueError(iterations) + if dklen is None: + dklen = outer.digest_size + if dklen < 1: + raise ValueError(dklen) + + hex_format_string = "%%0%ix" % (new(hash_name).digest_size * 2) + + dkey = b'' + loop = 1 + while len(dkey) < dklen: + prev = prf(salt + struct.pack(b'>I', loop)) + rkey = int(binascii.hexlify(prev), 16) + for i in xrange(iterations - 1): + prev = prf(prev) + rkey ^= int(binascii.hexlify(prev), 16) + loop += 1 + dkey += binascii.unhexlify(hex_format_string % rkey) + + return dkey[:dklen] + +# Cleanup locals() +del __always_supported, __func_name, __get_hash +del __py_new, __hash_new, __get_openssl_constructor diff --git a/heapq.py b/heapq.py new file mode 100644 index 0000000..4b2c0c4 --- /dev/null +++ b/heapq.py @@ -0,0 +1,485 @@ +# -*- coding: latin-1 -*- + +"""Heap queue algorithm (a.k.a. priority queue). + +Heaps are arrays for which a[k] <= a[2*k+1] and a[k] <= a[2*k+2] for +all k, counting elements from 0. For the sake of comparison, +non-existing elements are considered to be infinite. The interesting +property of a heap is that a[0] is always its smallest element. + +Usage: + +heap = [] # creates an empty heap +heappush(heap, item) # pushes a new item on the heap +item = heappop(heap) # pops the smallest item from the heap +item = heap[0] # smallest item on the heap without popping it +heapify(x) # transforms list into a heap, in-place, in linear time +item = heapreplace(heap, item) # pops and returns smallest item, and adds + # new item; the heap size is unchanged + +Our API differs from textbook heap algorithms as follows: + +- We use 0-based indexing. This makes the relationship between the + index for a node and the indexes for its children slightly less + obvious, but is more suitable since Python uses 0-based indexing. + +- Our heappop() method returns the smallest item, not the largest. + +These two make it possible to view the heap as a regular Python list +without surprises: heap[0] is the smallest item, and heap.sort() +maintains the heap invariant! +""" + +# Original code by Kevin O'Connor, augmented by Tim Peters and Raymond Hettinger + +__about__ = """Heap queues + +[explanation by François Pinard] + +Heaps are arrays for which a[k] <= a[2*k+1] and a[k] <= a[2*k+2] for +all k, counting elements from 0. For the sake of comparison, +non-existing elements are considered to be infinite. The interesting +property of a heap is that a[0] is always its smallest element. + +The strange invariant above is meant to be an efficient memory +representation for a tournament. The numbers below are `k', not a[k]: + + 0 + + 1 2 + + 3 4 5 6 + + 7 8 9 10 11 12 13 14 + + 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 + + +In the tree above, each cell `k' is topping `2*k+1' and `2*k+2'. In +an usual binary tournament we see in sports, each cell is the winner +over the two cells it tops, and we can trace the winner down the tree +to see all opponents s/he had. However, in many computer applications +of such tournaments, we do not need to trace the history of a winner. +To be more memory efficient, when a winner is promoted, we try to +replace it by something else at a lower level, and the rule becomes +that a cell and the two cells it tops contain three different items, +but the top cell "wins" over the two topped cells. + +If this heap invariant is protected at all time, index 0 is clearly +the overall winner. The simplest algorithmic way to remove it and +find the "next" winner is to move some loser (let's say cell 30 in the +diagram above) into the 0 position, and then percolate this new 0 down +the tree, exchanging values, until the invariant is re-established. +This is clearly logarithmic on the total number of items in the tree. +By iterating over all items, you get an O(n ln n) sort. + +A nice feature of this sort is that you can efficiently insert new +items while the sort is going on, provided that the inserted items are +not "better" than the last 0'th element you extracted. This is +especially useful in simulation contexts, where the tree holds all +incoming events, and the "win" condition means the smallest scheduled +time. When an event schedule other events for execution, they are +scheduled into the future, so they can easily go into the heap. So, a +heap is a good structure for implementing schedulers (this is what I +used for my MIDI sequencer :-). + +Various structures for implementing schedulers have been extensively +studied, and heaps are good for this, as they are reasonably speedy, +the speed is almost constant, and the worst case is not much different +than the average case. However, there are other representations which +are more efficient overall, yet the worst cases might be terrible. + +Heaps are also very useful in big disk sorts. You most probably all +know that a big sort implies producing "runs" (which are pre-sorted +sequences, which size is usually related to the amount of CPU memory), +followed by a merging passes for these runs, which merging is often +very cleverly organised[1]. It is very important that the initial +sort produces the longest runs possible. Tournaments are a good way +to that. If, using all the memory available to hold a tournament, you +replace and percolate items that happen to fit the current run, you'll +produce runs which are twice the size of the memory for random input, +and much better for input fuzzily ordered. + +Moreover, if you output the 0'th item on disk and get an input which +may not fit in the current tournament (because the value "wins" over +the last output value), it cannot fit in the heap, so the size of the +heap decreases. The freed memory could be cleverly reused immediately +for progressively building a second heap, which grows at exactly the +same rate the first heap is melting. When the first heap completely +vanishes, you switch heaps and start a new run. Clever and quite +effective! + +In a word, heaps are useful memory structures to know. I use them in +a few applications, and I think it is good to keep a `heap' module +around. :-) + +-------------------- +[1] The disk balancing algorithms which are current, nowadays, are +more annoying than clever, and this is a consequence of the seeking +capabilities of the disks. On devices which cannot seek, like big +tape drives, the story was quite different, and one had to be very +clever to ensure (far in advance) that each tape movement will be the +most effective possible (that is, will best participate at +"progressing" the merge). Some tapes were even able to read +backwards, and this was also used to avoid the rewinding time. +Believe me, real good tape sorts were quite spectacular to watch! +From all times, sorting has always been a Great Art! :-) +""" + +__all__ = ['heappush', 'heappop', 'heapify', 'heapreplace', 'merge', + 'nlargest', 'nsmallest', 'heappushpop'] + +from itertools import islice, count, imap, izip, tee, chain +from operator import itemgetter + +def cmp_lt(x, y): + # Use __lt__ if available; otherwise, try __le__. + # In Py3.x, only __lt__ will be called. + return (x < y) if hasattr(x, '__lt__') else (not y <= x) + +def heappush(heap, item): + """Push item onto heap, maintaining the heap invariant.""" + heap.append(item) + _siftdown(heap, 0, len(heap)-1) + +def heappop(heap): + """Pop the smallest item off the heap, maintaining the heap invariant.""" + lastelt = heap.pop() # raises appropriate IndexError if heap is empty + if heap: + returnitem = heap[0] + heap[0] = lastelt + _siftup(heap, 0) + else: + returnitem = lastelt + return returnitem + +def heapreplace(heap, item): + """Pop and return the current smallest value, and add the new item. + + This is more efficient than heappop() followed by heappush(), and can be + more appropriate when using a fixed-size heap. Note that the value + returned may be larger than item! That constrains reasonable uses of + this routine unless written as part of a conditional replacement: + + if item > heap[0]: + item = heapreplace(heap, item) + """ + returnitem = heap[0] # raises appropriate IndexError if heap is empty + heap[0] = item + _siftup(heap, 0) + return returnitem + +def heappushpop(heap, item): + """Fast version of a heappush followed by a heappop.""" + if heap and cmp_lt(heap[0], item): + item, heap[0] = heap[0], item + _siftup(heap, 0) + return item + +def heapify(x): + """Transform list into a heap, in-place, in O(len(x)) time.""" + n = len(x) + # Transform bottom-up. The largest index there's any point to looking at + # is the largest with a child index in-range, so must have 2*i + 1 < n, + # or i < (n-1)/2. If n is even = 2*j, this is (2*j-1)/2 = j-1/2 so + # j-1 is the largest, which is n//2 - 1. If n is odd = 2*j+1, this is + # (2*j+1-1)/2 = j so j-1 is the largest, and that's again n//2-1. + for i in reversed(xrange(n//2)): + _siftup(x, i) + +def _heappushpop_max(heap, item): + """Maxheap version of a heappush followed by a heappop.""" + if heap and cmp_lt(item, heap[0]): + item, heap[0] = heap[0], item + _siftup_max(heap, 0) + return item + +def _heapify_max(x): + """Transform list into a maxheap, in-place, in O(len(x)) time.""" + n = len(x) + for i in reversed(range(n//2)): + _siftup_max(x, i) + +def nlargest(n, iterable): + """Find the n largest elements in a dataset. + + Equivalent to: sorted(iterable, reverse=True)[:n] + """ + if n < 0: + return [] + it = iter(iterable) + result = list(islice(it, n)) + if not result: + return result + heapify(result) + _heappushpop = heappushpop + for elem in it: + _heappushpop(result, elem) + result.sort(reverse=True) + return result + +def nsmallest(n, iterable): + """Find the n smallest elements in a dataset. + + Equivalent to: sorted(iterable)[:n] + """ + if n < 0: + return [] + it = iter(iterable) + result = list(islice(it, n)) + if not result: + return result + _heapify_max(result) + _heappushpop = _heappushpop_max + for elem in it: + _heappushpop(result, elem) + result.sort() + return result + +# 'heap' is a heap at all indices >= startpos, except possibly for pos. pos +# is the index of a leaf with a possibly out-of-order value. Restore the +# heap invariant. +def _siftdown(heap, startpos, pos): + newitem = heap[pos] + # Follow the path to the root, moving parents down until finding a place + # newitem fits. + while pos > startpos: + parentpos = (pos - 1) >> 1 + parent = heap[parentpos] + if cmp_lt(newitem, parent): + heap[pos] = parent + pos = parentpos + continue + break + heap[pos] = newitem + +# The child indices of heap index pos are already heaps, and we want to make +# a heap at index pos too. We do this by bubbling the smaller child of +# pos up (and so on with that child's children, etc) until hitting a leaf, +# then using _siftdown to move the oddball originally at index pos into place. +# +# We *could* break out of the loop as soon as we find a pos where newitem <= +# both its children, but turns out that's not a good idea, and despite that +# many books write the algorithm that way. During a heap pop, the last array +# element is sifted in, and that tends to be large, so that comparing it +# against values starting from the root usually doesn't pay (= usually doesn't +# get us out of the loop early). See Knuth, Volume 3, where this is +# explained and quantified in an exercise. +# +# Cutting the # of comparisons is important, since these routines have no +# way to extract "the priority" from an array element, so that intelligence +# is likely to be hiding in custom __cmp__ methods, or in array elements +# storing (priority, record) tuples. Comparisons are thus potentially +# expensive. +# +# On random arrays of length 1000, making this change cut the number of +# comparisons made by heapify() a little, and those made by exhaustive +# heappop() a lot, in accord with theory. Here are typical results from 3 +# runs (3 just to demonstrate how small the variance is): +# +# Compares needed by heapify Compares needed by 1000 heappops +# -------------------------- -------------------------------- +# 1837 cut to 1663 14996 cut to 8680 +# 1855 cut to 1659 14966 cut to 8678 +# 1847 cut to 1660 15024 cut to 8703 +# +# Building the heap by using heappush() 1000 times instead required +# 2198, 2148, and 2219 compares: heapify() is more efficient, when +# you can use it. +# +# The total compares needed by list.sort() on the same lists were 8627, +# 8627, and 8632 (this should be compared to the sum of heapify() and +# heappop() compares): list.sort() is (unsurprisingly!) more efficient +# for sorting. + +def _siftup(heap, pos): + endpos = len(heap) + startpos = pos + newitem = heap[pos] + # Bubble up the smaller child until hitting a leaf. + childpos = 2*pos + 1 # leftmost child position + while childpos < endpos: + # Set childpos to index of smaller child. + rightpos = childpos + 1 + if rightpos < endpos and not cmp_lt(heap[childpos], heap[rightpos]): + childpos = rightpos + # Move the smaller child up. + heap[pos] = heap[childpos] + pos = childpos + childpos = 2*pos + 1 + # The leaf at pos is empty now. Put newitem there, and bubble it up + # to its final resting place (by sifting its parents down). + heap[pos] = newitem + _siftdown(heap, startpos, pos) + +def _siftdown_max(heap, startpos, pos): + 'Maxheap variant of _siftdown' + newitem = heap[pos] + # Follow the path to the root, moving parents down until finding a place + # newitem fits. + while pos > startpos: + parentpos = (pos - 1) >> 1 + parent = heap[parentpos] + if cmp_lt(parent, newitem): + heap[pos] = parent + pos = parentpos + continue + break + heap[pos] = newitem + +def _siftup_max(heap, pos): + 'Maxheap variant of _siftup' + endpos = len(heap) + startpos = pos + newitem = heap[pos] + # Bubble up the larger child until hitting a leaf. + childpos = 2*pos + 1 # leftmost child position + while childpos < endpos: + # Set childpos to index of larger child. + rightpos = childpos + 1 + if rightpos < endpos and not cmp_lt(heap[rightpos], heap[childpos]): + childpos = rightpos + # Move the larger child up. + heap[pos] = heap[childpos] + pos = childpos + childpos = 2*pos + 1 + # The leaf at pos is empty now. Put newitem there, and bubble it up + # to its final resting place (by sifting its parents down). + heap[pos] = newitem + _siftdown_max(heap, startpos, pos) + +# If available, use C implementation +try: + from _heapq import * +except ImportError: + pass + +def merge(*iterables): + '''Merge multiple sorted inputs into a single sorted output. + + Similar to sorted(itertools.chain(*iterables)) but returns a generator, + does not pull the data into memory all at once, and assumes that each of + the input streams is already sorted (smallest to largest). + + >>> list(merge([1,3,5,7], [0,2,4,8], [5,10,15,20], [], [25])) + [0, 1, 2, 3, 4, 5, 5, 7, 8, 10, 15, 20, 25] + + ''' + _heappop, _heapreplace, _StopIteration = heappop, heapreplace, StopIteration + _len = len + + h = [] + h_append = h.append + for itnum, it in enumerate(map(iter, iterables)): + try: + next = it.next + h_append([next(), itnum, next]) + except _StopIteration: + pass + heapify(h) + + while _len(h) > 1: + try: + while 1: + v, itnum, next = s = h[0] + yield v + s[0] = next() # raises StopIteration when exhausted + _heapreplace(h, s) # restore heap condition + except _StopIteration: + _heappop(h) # remove empty iterator + if h: + # fast case when only a single iterator remains + v, itnum, next = h[0] + yield v + for v in next.__self__: + yield v + +# Extend the implementations of nsmallest and nlargest to use a key= argument +_nsmallest = nsmallest +def nsmallest(n, iterable, key=None): + """Find the n smallest elements in a dataset. + + Equivalent to: sorted(iterable, key=key)[:n] + """ + # Short-cut for n==1 is to use min() when len(iterable)>0 + if n == 1: + it = iter(iterable) + head = list(islice(it, 1)) + if not head: + return [] + if key is None: + return [min(chain(head, it))] + return [min(chain(head, it), key=key)] + + # When n>=size, it's faster to use sorted() + try: + size = len(iterable) + except (TypeError, AttributeError): + pass + else: + if n >= size: + return sorted(iterable, key=key)[:n] + + # When key is none, use simpler decoration + if key is None: + it = izip(iterable, count()) # decorate + result = _nsmallest(n, it) + return map(itemgetter(0), result) # undecorate + + # General case, slowest method + in1, in2 = tee(iterable) + it = izip(imap(key, in1), count(), in2) # decorate + result = _nsmallest(n, it) + return map(itemgetter(2), result) # undecorate + +_nlargest = nlargest +def nlargest(n, iterable, key=None): + """Find the n largest elements in a dataset. + + Equivalent to: sorted(iterable, key=key, reverse=True)[:n] + """ + + # Short-cut for n==1 is to use max() when len(iterable)>0 + if n == 1: + it = iter(iterable) + head = list(islice(it, 1)) + if not head: + return [] + if key is None: + return [max(chain(head, it))] + return [max(chain(head, it), key=key)] + + # When n>=size, it's faster to use sorted() + try: + size = len(iterable) + except (TypeError, AttributeError): + pass + else: + if n >= size: + return sorted(iterable, key=key, reverse=True)[:n] + + # When key is none, use simpler decoration + if key is None: + it = izip(iterable, count(0,-1)) # decorate + result = _nlargest(n, it) + return map(itemgetter(0), result) # undecorate + + # General case, slowest method + in1, in2 = tee(iterable) + it = izip(imap(key, in1), count(0,-1), in2) # decorate + result = _nlargest(n, it) + return map(itemgetter(2), result) # undecorate + +if __name__ == "__main__": + # Simple sanity test + heap = [] + data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 0] + for item in data: + heappush(heap, item) + sort = [] + while heap: + sort.append(heappop(heap)) + print sort + + import doctest + doctest.testmod() diff --git a/httplib.py b/httplib.py new file mode 100644 index 0000000..7223ba1 --- /dev/null +++ b/httplib.py @@ -0,0 +1,1445 @@ +r"""HTTP/1.1 client library + + + + +HTTPConnection goes through a number of "states", which define when a client +may legally make another request or fetch the response for a particular +request. This diagram details these state transitions: + + (null) + | + | HTTPConnection() + v + Idle + | + | putrequest() + v + Request-started + | + | ( putheader() )* endheaders() + v + Request-sent + | + | response = getresponse() + v + Unread-response [Response-headers-read] + |\____________________ + | | + | response.read() | putrequest() + v v + Idle Req-started-unread-response + ______/| + / | + response.read() | | ( putheader() )* endheaders() + v v + Request-started Req-sent-unread-response + | + | response.read() + v + Request-sent + +This diagram presents the following rules: + -- a second request may not be started until {response-headers-read} + -- a response [object] cannot be retrieved until {request-sent} + -- there is no differentiation between an unread response body and a + partially read response body + +Note: this enforcement is applied by the HTTPConnection class. The + HTTPResponse class does not enforce this state machine, which + implies sophisticated clients may accelerate the request/response + pipeline. Caution should be taken, though: accelerating the states + beyond the above pattern may imply knowledge of the server's + connection-close behavior for certain requests. For example, it + is impossible to tell whether the server will close the connection + UNTIL the response headers have been read; this means that further + requests cannot be placed into the pipeline until it is known that + the server will NOT be closing the connection. + +Logical State __state __response +------------- ------- ---------- +Idle _CS_IDLE None +Request-started _CS_REQ_STARTED None +Request-sent _CS_REQ_SENT None +Unread-response _CS_IDLE +Req-started-unread-response _CS_REQ_STARTED +Req-sent-unread-response _CS_REQ_SENT +""" + +from array import array +import os +import re +import socket +from sys import py3kwarning +from urlparse import urlsplit +import warnings +with warnings.catch_warnings(): + if py3kwarning: + warnings.filterwarnings("ignore", ".*mimetools has been removed", + DeprecationWarning) + import mimetools + +try: + from cStringIO import StringIO +except ImportError: + from StringIO import StringIO + +__all__ = ["HTTP", "HTTPResponse", "HTTPConnection", + "HTTPException", "NotConnected", "UnknownProtocol", + "UnknownTransferEncoding", "UnimplementedFileMode", + "IncompleteRead", "InvalidURL", "ImproperConnectionState", + "CannotSendRequest", "CannotSendHeader", "ResponseNotReady", + "BadStatusLine", "error", "responses"] + +HTTP_PORT = 80 +HTTPS_PORT = 443 + +_UNKNOWN = 'UNKNOWN' + +# connection states +_CS_IDLE = 'Idle' +_CS_REQ_STARTED = 'Request-started' +_CS_REQ_SENT = 'Request-sent' + +# status codes +# informational +CONTINUE = 100 +SWITCHING_PROTOCOLS = 101 +PROCESSING = 102 + +# successful +OK = 200 +CREATED = 201 +ACCEPTED = 202 +NON_AUTHORITATIVE_INFORMATION = 203 +NO_CONTENT = 204 +RESET_CONTENT = 205 +PARTIAL_CONTENT = 206 +MULTI_STATUS = 207 +IM_USED = 226 + +# redirection +MULTIPLE_CHOICES = 300 +MOVED_PERMANENTLY = 301 +FOUND = 302 +SEE_OTHER = 303 +NOT_MODIFIED = 304 +USE_PROXY = 305 +TEMPORARY_REDIRECT = 307 + +# client error +BAD_REQUEST = 400 +UNAUTHORIZED = 401 +PAYMENT_REQUIRED = 402 +FORBIDDEN = 403 +NOT_FOUND = 404 +METHOD_NOT_ALLOWED = 405 +NOT_ACCEPTABLE = 406 +PROXY_AUTHENTICATION_REQUIRED = 407 +REQUEST_TIMEOUT = 408 +CONFLICT = 409 +GONE = 410 +LENGTH_REQUIRED = 411 +PRECONDITION_FAILED = 412 +REQUEST_ENTITY_TOO_LARGE = 413 +REQUEST_URI_TOO_LONG = 414 +UNSUPPORTED_MEDIA_TYPE = 415 +REQUESTED_RANGE_NOT_SATISFIABLE = 416 +EXPECTATION_FAILED = 417 +UNPROCESSABLE_ENTITY = 422 +LOCKED = 423 +FAILED_DEPENDENCY = 424 +UPGRADE_REQUIRED = 426 + +# server error +INTERNAL_SERVER_ERROR = 500 +NOT_IMPLEMENTED = 501 +BAD_GATEWAY = 502 +SERVICE_UNAVAILABLE = 503 +GATEWAY_TIMEOUT = 504 +HTTP_VERSION_NOT_SUPPORTED = 505 +INSUFFICIENT_STORAGE = 507 +NOT_EXTENDED = 510 + +# Mapping status codes to official W3C names +responses = { + 100: 'Continue', + 101: 'Switching Protocols', + + 200: 'OK', + 201: 'Created', + 202: 'Accepted', + 203: 'Non-Authoritative Information', + 204: 'No Content', + 205: 'Reset Content', + 206: 'Partial Content', + + 300: 'Multiple Choices', + 301: 'Moved Permanently', + 302: 'Found', + 303: 'See Other', + 304: 'Not Modified', + 305: 'Use Proxy', + 306: '(Unused)', + 307: 'Temporary Redirect', + + 400: 'Bad Request', + 401: 'Unauthorized', + 402: 'Payment Required', + 403: 'Forbidden', + 404: 'Not Found', + 405: 'Method Not Allowed', + 406: 'Not Acceptable', + 407: 'Proxy Authentication Required', + 408: 'Request Timeout', + 409: 'Conflict', + 410: 'Gone', + 411: 'Length Required', + 412: 'Precondition Failed', + 413: 'Request Entity Too Large', + 414: 'Request-URI Too Long', + 415: 'Unsupported Media Type', + 416: 'Requested Range Not Satisfiable', + 417: 'Expectation Failed', + + 500: 'Internal Server Error', + 501: 'Not Implemented', + 502: 'Bad Gateway', + 503: 'Service Unavailable', + 504: 'Gateway Timeout', + 505: 'HTTP Version Not Supported', +} + +# maximal amount of data to read at one time in _safe_read +MAXAMOUNT = 1048576 + +# maximal line length when calling readline(). +_MAXLINE = 65536 + +# maximum amount of headers accepted +_MAXHEADERS = 100 + +# Header name/value ABNF (http://tools.ietf.org/html/rfc7230#section-3.2) +# +# VCHAR = %x21-7E +# obs-text = %x80-FF +# header-field = field-name ":" OWS field-value OWS +# field-name = token +# field-value = *( field-content / obs-fold ) +# field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ] +# field-vchar = VCHAR / obs-text +# +# obs-fold = CRLF 1*( SP / HTAB ) +# ; obsolete line folding +# ; see Section 3.2.4 + +# token = 1*tchar +# +# tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" +# / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~" +# / DIGIT / ALPHA +# ; any VCHAR, except delimiters +# +# VCHAR defined in http://tools.ietf.org/html/rfc5234#appendix-B.1 + +# the patterns for both name and value are more leniant than RFC +# definitions to allow for backwards compatibility +_is_legal_header_name = re.compile(r'\A[^:\s][^:\r\n]*\Z').match +_is_illegal_header_value = re.compile(r'\n(?![ \t])|\r(?![ \t\n])').search + +# We always set the Content-Length header for these methods because some +# servers will otherwise respond with a 411 +_METHODS_EXPECTING_BODY = {'PATCH', 'POST', 'PUT'} + + +class HTTPMessage(mimetools.Message): + + def addheader(self, key, value): + """Add header for field key handling repeats.""" + prev = self.dict.get(key) + if prev is None: + self.dict[key] = value + else: + combined = ", ".join((prev, value)) + self.dict[key] = combined + + def addcontinue(self, key, more): + """Add more field data from a continuation line.""" + prev = self.dict[key] + self.dict[key] = prev + "\n " + more + + def readheaders(self): + """Read header lines. + + Read header lines up to the entirely blank line that terminates them. + The (normally blank) line that ends the headers is skipped, but not + included in the returned list. If a non-header line ends the headers, + (which is an error), an attempt is made to backspace over it; it is + never included in the returned list. + + The variable self.status is set to the empty string if all went well, + otherwise it is an error message. The variable self.headers is a + completely uninterpreted list of lines contained in the header (so + printing them will reproduce the header exactly as it appears in the + file). + + If multiple header fields with the same name occur, they are combined + according to the rules in RFC 2616 sec 4.2: + + Appending each subsequent field-value to the first, each separated + by a comma. The order in which header fields with the same field-name + are received is significant to the interpretation of the combined + field value. + """ + # XXX The implementation overrides the readheaders() method of + # rfc822.Message. The base class design isn't amenable to + # customized behavior here so the method here is a copy of the + # base class code with a few small changes. + + self.dict = {} + self.unixfrom = '' + self.headers = hlist = [] + self.status = '' + headerseen = "" + firstline = 1 + startofline = unread = tell = None + if hasattr(self.fp, 'unread'): + unread = self.fp.unread + elif self.seekable: + tell = self.fp.tell + while True: + if len(hlist) > _MAXHEADERS: + raise HTTPException("got more than %d headers" % _MAXHEADERS) + if tell: + try: + startofline = tell() + except IOError: + startofline = tell = None + self.seekable = 0 + line = self.fp.readline(_MAXLINE + 1) + if len(line) > _MAXLINE: + raise LineTooLong("header line") + if not line: + self.status = 'EOF in headers' + break + # Skip unix From name time lines + if firstline and line.startswith('From '): + self.unixfrom = self.unixfrom + line + continue + firstline = 0 + if headerseen and line[0] in ' \t': + # XXX Not sure if continuation lines are handled properly + # for http and/or for repeating headers + # It's a continuation line. + hlist.append(line) + self.addcontinue(headerseen, line.strip()) + continue + elif self.iscomment(line): + # It's a comment. Ignore it. + continue + elif self.islast(line): + # Note! No pushback here! The delimiter line gets eaten. + break + headerseen = self.isheader(line) + if headerseen: + # It's a legal header line, save it. + hlist.append(line) + self.addheader(headerseen, line[len(headerseen)+1:].strip()) + continue + elif headerseen is not None: + # An empty header name. These aren't allowed in HTTP, but it's + # probably a benign mistake. Don't add the header, just keep + # going. + continue + else: + # It's not a header line; throw it back and stop here. + if not self.dict: + self.status = 'No headers' + else: + self.status = 'Non-header line where header expected' + # Try to undo the read. + if unread: + unread(line) + elif tell: + self.fp.seek(startofline) + else: + self.status = self.status + '; bad seek' + break + +class HTTPResponse: + + # strict: If true, raise BadStatusLine if the status line can't be + # parsed as a valid HTTP/1.0 or 1.1 status line. By default it is + # false because it prevents clients from talking to HTTP/0.9 + # servers. Note that a response with a sufficiently corrupted + # status line will look like an HTTP/0.9 response. + + # See RFC 2616 sec 19.6 and RFC 1945 sec 6 for details. + + def __init__(self, sock, debuglevel=0, strict=0, method=None, buffering=False): + if buffering: + # The caller won't be using any sock.recv() calls, so buffering + # is fine and recommended for performance. + self.fp = sock.makefile('rb') + else: + # The buffer size is specified as zero, because the headers of + # the response are read with readline(). If the reads were + # buffered the readline() calls could consume some of the + # response, which make be read via a recv() on the underlying + # socket. + self.fp = sock.makefile('rb', 0) + self.debuglevel = debuglevel + self.strict = strict + self._method = method + + self.msg = None + + # from the Status-Line of the response + self.version = _UNKNOWN # HTTP-Version + self.status = _UNKNOWN # Status-Code + self.reason = _UNKNOWN # Reason-Phrase + + self.chunked = _UNKNOWN # is "chunked" being used? + self.chunk_left = _UNKNOWN # bytes left to read in current chunk + self.length = _UNKNOWN # number of bytes left in response + self.will_close = _UNKNOWN # conn will close at end of response + + def _read_status(self): + # Initialize with Simple-Response defaults + line = self.fp.readline(_MAXLINE + 1) + if len(line) > _MAXLINE: + raise LineTooLong("header line") + if self.debuglevel > 0: + print "reply:", repr(line) + if not line: + # Presumably, the server closed the connection before + # sending a valid response. + raise BadStatusLine(line) + try: + [version, status, reason] = line.split(None, 2) + except ValueError: + try: + [version, status] = line.split(None, 1) + reason = "" + except ValueError: + # empty version will cause next test to fail and status + # will be treated as 0.9 response. + version = "" + if not version.startswith('HTTP/'): + if self.strict: + self.close() + raise BadStatusLine(line) + else: + # assume it's a Simple-Response from an 0.9 server + self.fp = LineAndFileWrapper(line, self.fp) + return "HTTP/0.9", 200, "" + + # The status code is a three-digit number + try: + status = int(status) + if status < 100 or status > 999: + raise BadStatusLine(line) + except ValueError: + raise BadStatusLine(line) + return version, status, reason + + def begin(self): + if self.msg is not None: + # we've already started reading the response + return + + # read until we get a non-100 response + while True: + version, status, reason = self._read_status() + if status != CONTINUE: + break + # skip the header from the 100 response + while True: + skip = self.fp.readline(_MAXLINE + 1) + if len(skip) > _MAXLINE: + raise LineTooLong("header line") + skip = skip.strip() + if not skip: + break + if self.debuglevel > 0: + print "header:", skip + + self.status = status + self.reason = reason.strip() + if version == 'HTTP/1.0': + self.version = 10 + elif version.startswith('HTTP/1.'): + self.version = 11 # use HTTP/1.1 code for HTTP/1.x where x>=1 + elif version == 'HTTP/0.9': + self.version = 9 + else: + raise UnknownProtocol(version) + + if self.version == 9: + self.length = None + self.chunked = 0 + self.will_close = 1 + self.msg = HTTPMessage(StringIO()) + return + + self.msg = HTTPMessage(self.fp, 0) + if self.debuglevel > 0: + for hdr in self.msg.headers: + print "header:", hdr, + + # don't let the msg keep an fp + self.msg.fp = None + + # are we using the chunked-style of transfer encoding? + tr_enc = self.msg.getheader('transfer-encoding') + if tr_enc and tr_enc.lower() == "chunked": + self.chunked = 1 + self.chunk_left = None + else: + self.chunked = 0 + + # will the connection close at the end of the response? + self.will_close = self._check_close() + + # do we have a Content-Length? + # NOTE: RFC 2616, S4.4, #3 says we ignore this if tr_enc is "chunked" + length = self.msg.getheader('content-length') + if length and not self.chunked: + try: + self.length = int(length) + except ValueError: + self.length = None + else: + if self.length < 0: # ignore nonsensical negative lengths + self.length = None + else: + self.length = None + + # does the body have a fixed length? (of zero) + if (status == NO_CONTENT or status == NOT_MODIFIED or + 100 <= status < 200 or # 1xx codes + self._method == 'HEAD'): + self.length = 0 + + # if the connection remains open, and we aren't using chunked, and + # a content-length was not provided, then assume that the connection + # WILL close. + if not self.will_close and \ + not self.chunked and \ + self.length is None: + self.will_close = 1 + + def _check_close(self): + conn = self.msg.getheader('connection') + if self.version == 11: + # An HTTP/1.1 proxy is assumed to stay open unless + # explicitly closed. + conn = self.msg.getheader('connection') + if conn and "close" in conn.lower(): + return True + return False + + # Some HTTP/1.0 implementations have support for persistent + # connections, using rules different than HTTP/1.1. + + # For older HTTP, Keep-Alive indicates persistent connection. + if self.msg.getheader('keep-alive'): + return False + + # At least Akamai returns a "Connection: Keep-Alive" header, + # which was supposed to be sent by the client. + if conn and "keep-alive" in conn.lower(): + return False + + # Proxy-Connection is a netscape hack. + pconn = self.msg.getheader('proxy-connection') + if pconn and "keep-alive" in pconn.lower(): + return False + + # otherwise, assume it will close + return True + + def close(self): + fp = self.fp + if fp: + self.fp = None + fp.close() + + def isclosed(self): + # NOTE: it is possible that we will not ever call self.close(). This + # case occurs when will_close is TRUE, length is None, and we + # read up to the last byte, but NOT past it. + # + # IMPLIES: if will_close is FALSE, then self.close() will ALWAYS be + # called, meaning self.isclosed() is meaningful. + return self.fp is None + + # XXX It would be nice to have readline and __iter__ for this, too. + + def read(self, amt=None): + if self.fp is None: + return '' + + if self._method == 'HEAD': + self.close() + return '' + + if self.chunked: + return self._read_chunked(amt) + + if amt is None: + # unbounded read + if self.length is None: + s = self.fp.read() + else: + try: + s = self._safe_read(self.length) + except IncompleteRead: + self.close() + raise + self.length = 0 + self.close() # we read everything + return s + + if self.length is not None: + if amt > self.length: + # clip the read to the "end of response" + amt = self.length + + # we do not use _safe_read() here because this may be a .will_close + # connection, and the user is reading more bytes than will be provided + # (for example, reading in 1k chunks) + s = self.fp.read(amt) + if not s and amt: + # Ideally, we would raise IncompleteRead if the content-length + # wasn't satisfied, but it might break compatibility. + self.close() + if self.length is not None: + self.length -= len(s) + if not self.length: + self.close() + + return s + + def _read_chunked(self, amt): + assert self.chunked != _UNKNOWN + chunk_left = self.chunk_left + value = [] + while True: + if chunk_left is None: + line = self.fp.readline(_MAXLINE + 1) + if len(line) > _MAXLINE: + raise LineTooLong("chunk size") + i = line.find(';') + if i >= 0: + line = line[:i] # strip chunk-extensions + try: + chunk_left = int(line, 16) + except ValueError: + # close the connection as protocol synchronisation is + # probably lost + self.close() + raise IncompleteRead(''.join(value)) + if chunk_left == 0: + break + if amt is None: + value.append(self._safe_read(chunk_left)) + elif amt < chunk_left: + value.append(self._safe_read(amt)) + self.chunk_left = chunk_left - amt + return ''.join(value) + elif amt == chunk_left: + value.append(self._safe_read(amt)) + self._safe_read(2) # toss the CRLF at the end of the chunk + self.chunk_left = None + return ''.join(value) + else: + value.append(self._safe_read(chunk_left)) + amt -= chunk_left + + # we read the whole chunk, get another + self._safe_read(2) # toss the CRLF at the end of the chunk + chunk_left = None + + # read and discard trailer up to the CRLF terminator + ### note: we shouldn't have any trailers! + while True: + line = self.fp.readline(_MAXLINE + 1) + if len(line) > _MAXLINE: + raise LineTooLong("trailer line") + if not line: + # a vanishingly small number of sites EOF without + # sending the trailer + break + if line == '\r\n': + break + + # we read everything; close the "file" + self.close() + + return ''.join(value) + + def _safe_read(self, amt): + """Read the number of bytes requested, compensating for partial reads. + + Normally, we have a blocking socket, but a read() can be interrupted + by a signal (resulting in a partial read). + + Note that we cannot distinguish between EOF and an interrupt when zero + bytes have been read. IncompleteRead() will be raised in this + situation. + + This function should be used when bytes "should" be present for + reading. If the bytes are truly not available (due to EOF), then the + IncompleteRead exception can be used to detect the problem. + """ + # NOTE(gps): As of svn r74426 socket._fileobject.read(x) will never + # return less than x bytes unless EOF is encountered. It now handles + # signal interruptions (socket.error EINTR) internally. This code + # never caught that exception anyways. It seems largely pointless. + # self.fp.read(amt) will work fine. + s = [] + while amt > 0: + chunk = self.fp.read(min(amt, MAXAMOUNT)) + if not chunk: + raise IncompleteRead(''.join(s), amt) + s.append(chunk) + amt -= len(chunk) + return ''.join(s) + + def fileno(self): + return self.fp.fileno() + + def getheader(self, name, default=None): + if self.msg is None: + raise ResponseNotReady() + return self.msg.getheader(name, default) + + def getheaders(self): + """Return list of (header, value) tuples.""" + if self.msg is None: + raise ResponseNotReady() + return self.msg.items() + + +class HTTPConnection: + + _http_vsn = 11 + _http_vsn_str = 'HTTP/1.1' + + response_class = HTTPResponse + default_port = HTTP_PORT + auto_open = 1 + debuglevel = 0 + strict = 0 + + def __init__(self, host, port=None, strict=None, + timeout=socket._GLOBAL_DEFAULT_TIMEOUT, source_address=None): + self.timeout = timeout + self.source_address = source_address + self.sock = None + self._buffer = [] + self.__response = None + self.__state = _CS_IDLE + self._method = None + self._tunnel_host = None + self._tunnel_port = None + self._tunnel_headers = {} + if strict is not None: + self.strict = strict + + (self.host, self.port) = self._get_hostport(host, port) + + # This is stored as an instance variable to allow unittests + # to replace with a suitable mock + self._create_connection = socket.create_connection + + def set_tunnel(self, host, port=None, headers=None): + """ Set up host and port for HTTP CONNECT tunnelling. + + In a connection that uses HTTP Connect tunneling, the host passed to the + constructor is used as proxy server that relays all communication to the + endpoint passed to set_tunnel. This is done by sending a HTTP CONNECT + request to the proxy server when the connection is established. + + This method must be called before the HTTP connection has been + established. + + The headers argument should be a mapping of extra HTTP headers + to send with the CONNECT request. + """ + # Verify if this is required. + if self.sock: + raise RuntimeError("Can't setup tunnel for established connection.") + + self._tunnel_host, self._tunnel_port = self._get_hostport(host, port) + if headers: + self._tunnel_headers = headers + else: + self._tunnel_headers.clear() + + def _get_hostport(self, host, port): + if port is None: + i = host.rfind(':') + j = host.rfind(']') # ipv6 addresses have [...] + if i > j: + try: + port = int(host[i+1:]) + except ValueError: + if host[i+1:] == "": # http://foo.com:/ == http://foo.com/ + port = self.default_port + else: + raise InvalidURL("nonnumeric port: '%s'" % host[i+1:]) + host = host[:i] + else: + port = self.default_port + if host and host[0] == '[' and host[-1] == ']': + host = host[1:-1] + return (host, port) + + def set_debuglevel(self, level): + self.debuglevel = level + + def _tunnel(self): + self.send("CONNECT %s:%d HTTP/1.0\r\n" % (self._tunnel_host, + self._tunnel_port)) + for header, value in self._tunnel_headers.iteritems(): + self.send("%s: %s\r\n" % (header, value)) + self.send("\r\n") + response = self.response_class(self.sock, strict = self.strict, + method = self._method) + (version, code, message) = response._read_status() + + if version == "HTTP/0.9": + # HTTP/0.9 doesn't support the CONNECT verb, so if httplib has + # concluded HTTP/0.9 is being used something has gone wrong. + self.close() + raise socket.error("Invalid response from tunnel request") + if code != 200: + self.close() + raise socket.error("Tunnel connection failed: %d %s" % (code, + message.strip())) + while True: + line = response.fp.readline(_MAXLINE + 1) + if len(line) > _MAXLINE: + raise LineTooLong("header line") + if not line: + # for sites which EOF without sending trailer + break + if line == '\r\n': + break + + + def connect(self): + """Connect to the host and port specified in __init__.""" + self.sock = self._create_connection((self.host,self.port), + self.timeout, self.source_address) + + if self._tunnel_host: + self._tunnel() + + def close(self): + """Close the connection to the HTTP server.""" + self.__state = _CS_IDLE + try: + sock = self.sock + if sock: + self.sock = None + sock.close() # close it manually... there may be other refs + finally: + response = self.__response + if response: + self.__response = None + response.close() + + def send(self, data): + """Send `data' to the server.""" + if self.sock is None: + if self.auto_open: + self.connect() + else: + raise NotConnected() + + if self.debuglevel > 0: + print "send:", repr(data) + blocksize = 8192 + if hasattr(data,'read') and not isinstance(data, array): + if self.debuglevel > 0: print "sendIng a read()able" + datablock = data.read(blocksize) + while datablock: + self.sock.sendall(datablock) + datablock = data.read(blocksize) + else: + self.sock.sendall(data) + + def _output(self, s): + """Add a line of output to the current request buffer. + + Assumes that the line does *not* end with \\r\\n. + """ + self._buffer.append(s) + + def _send_output(self, message_body=None): + """Send the currently buffered request and clear the buffer. + + Appends an extra \\r\\n to the buffer. + A message_body may be specified, to be appended to the request. + """ + self._buffer.extend(("", "")) + msg = "\r\n".join(self._buffer) + del self._buffer[:] + # If msg and message_body are sent in a single send() call, + # it will avoid performance problems caused by the interaction + # between delayed ack and the Nagle algorithm. + if isinstance(message_body, str): + msg += message_body + message_body = None + self.send(msg) + if message_body is not None: + #message_body was not a string (i.e. it is a file) and + #we must run the risk of Nagle + self.send(message_body) + + def putrequest(self, method, url, skip_host=0, skip_accept_encoding=0): + """Send a request to the server. + + `method' specifies an HTTP request method, e.g. 'GET'. + `url' specifies the object being requested, e.g. '/index.html'. + `skip_host' if True does not add automatically a 'Host:' header + `skip_accept_encoding' if True does not add automatically an + 'Accept-Encoding:' header + """ + + # if a prior response has been completed, then forget about it. + if self.__response and self.__response.isclosed(): + self.__response = None + + + # in certain cases, we cannot issue another request on this connection. + # this occurs when: + # 1) we are in the process of sending a request. (_CS_REQ_STARTED) + # 2) a response to a previous request has signalled that it is going + # to close the connection upon completion. + # 3) the headers for the previous response have not been read, thus + # we cannot determine whether point (2) is true. (_CS_REQ_SENT) + # + # if there is no prior response, then we can request at will. + # + # if point (2) is true, then we will have passed the socket to the + # response (effectively meaning, "there is no prior response"), and + # will open a new one when a new request is made. + # + # Note: if a prior response exists, then we *can* start a new request. + # We are not allowed to begin fetching the response to this new + # request, however, until that prior response is complete. + # + if self.__state == _CS_IDLE: + self.__state = _CS_REQ_STARTED + else: + raise CannotSendRequest() + + # Save the method we use, we need it later in the response phase + self._method = method + if not url: + url = '/' + hdr = '%s %s %s' % (method, url, self._http_vsn_str) + + self._output(hdr) + + if self._http_vsn == 11: + # Issue some standard headers for better HTTP/1.1 compliance + + if not skip_host: + # this header is issued *only* for HTTP/1.1 + # connections. more specifically, this means it is + # only issued when the client uses the new + # HTTPConnection() class. backwards-compat clients + # will be using HTTP/1.0 and those clients may be + # issuing this header themselves. we should NOT issue + # it twice; some web servers (such as Apache) barf + # when they see two Host: headers + + # If we need a non-standard port,include it in the + # header. If the request is going through a proxy, + # but the host of the actual URL, not the host of the + # proxy. + + netloc = '' + if url.startswith('http'): + nil, netloc, nil, nil, nil = urlsplit(url) + + if netloc: + try: + netloc_enc = netloc.encode("ascii") + except UnicodeEncodeError: + netloc_enc = netloc.encode("idna") + self.putheader('Host', netloc_enc) + else: + if self._tunnel_host: + host = self._tunnel_host + port = self._tunnel_port + else: + host = self.host + port = self.port + + try: + host_enc = host.encode("ascii") + except UnicodeEncodeError: + host_enc = host.encode("idna") + # Wrap the IPv6 Host Header with [] (RFC 2732) + if host_enc.find(':') >= 0: + host_enc = "[" + host_enc + "]" + if port == self.default_port: + self.putheader('Host', host_enc) + else: + self.putheader('Host', "%s:%s" % (host_enc, port)) + + # note: we are assuming that clients will not attempt to set these + # headers since *this* library must deal with the + # consequences. this also means that when the supporting + # libraries are updated to recognize other forms, then this + # code should be changed (removed or updated). + + # we only want a Content-Encoding of "identity" since we don't + # support encodings such as x-gzip or x-deflate. + if not skip_accept_encoding: + self.putheader('Accept-Encoding', 'identity') + + # we can accept "chunked" Transfer-Encodings, but no others + # NOTE: no TE header implies *only* "chunked" + #self.putheader('TE', 'chunked') + + # if TE is supplied in the header, then it must appear in a + # Connection header. + #self.putheader('Connection', 'TE') + + else: + # For HTTP/1.0, the server will assume "not chunked" + pass + + def putheader(self, header, *values): + """Send a request header line to the server. + + For example: h.putheader('Accept', 'text/html') + """ + if self.__state != _CS_REQ_STARTED: + raise CannotSendHeader() + + header = '%s' % header + if not _is_legal_header_name(header): + raise ValueError('Invalid header name %r' % (header,)) + + values = [str(v) for v in values] + for one_value in values: + if _is_illegal_header_value(one_value): + raise ValueError('Invalid header value %r' % (one_value,)) + + hdr = '%s: %s' % (header, '\r\n\t'.join(values)) + self._output(hdr) + + def endheaders(self, message_body=None): + """Indicate that the last header line has been sent to the server. + + This method sends the request to the server. The optional + message_body argument can be used to pass a message body + associated with the request. The message body will be sent in + the same packet as the message headers if it is string, otherwise it is + sent as a separate packet. + """ + if self.__state == _CS_REQ_STARTED: + self.__state = _CS_REQ_SENT + else: + raise CannotSendHeader() + self._send_output(message_body) + + def request(self, method, url, body=None, headers={}): + """Send a complete request to the server.""" + self._send_request(method, url, body, headers) + + def _set_content_length(self, body, method): + # Set the content-length based on the body. If the body is "empty", we + # set Content-Length: 0 for methods that expect a body (RFC 7230, + # Section 3.3.2). If the body is set for other methods, we set the + # header provided we can figure out what the length is. + thelen = None + if body is None and method.upper() in _METHODS_EXPECTING_BODY: + thelen = '0' + elif body is not None: + try: + thelen = str(len(body)) + except (TypeError, AttributeError): + # If this is a file-like object, try to + # fstat its file descriptor + try: + thelen = str(os.fstat(body.fileno()).st_size) + except (AttributeError, OSError): + # Don't send a length if this failed + if self.debuglevel > 0: print "Cannot stat!!" + + if thelen is not None: + self.putheader('Content-Length', thelen) + + def _send_request(self, method, url, body, headers): + # Honor explicitly requested Host: and Accept-Encoding: headers. + header_names = dict.fromkeys([k.lower() for k in headers]) + skips = {} + if 'host' in header_names: + skips['skip_host'] = 1 + if 'accept-encoding' in header_names: + skips['skip_accept_encoding'] = 1 + + self.putrequest(method, url, **skips) + + if 'content-length' not in header_names: + self._set_content_length(body, method) + for hdr, value in headers.iteritems(): + self.putheader(hdr, value) + self.endheaders(body) + + def getresponse(self, buffering=False): + "Get the response from the server." + + # if a prior response has been completed, then forget about it. + if self.__response and self.__response.isclosed(): + self.__response = None + + # + # if a prior response exists, then it must be completed (otherwise, we + # cannot read this response's header to determine the connection-close + # behavior) + # + # note: if a prior response existed, but was connection-close, then the + # socket and response were made independent of this HTTPConnection + # object since a new request requires that we open a whole new + # connection + # + # this means the prior response had one of two states: + # 1) will_close: this connection was reset and the prior socket and + # response operate independently + # 2) persistent: the response was retained and we await its + # isclosed() status to become true. + # + if self.__state != _CS_REQ_SENT or self.__response: + raise ResponseNotReady() + + args = (self.sock,) + kwds = {"strict":self.strict, "method":self._method} + if self.debuglevel > 0: + args += (self.debuglevel,) + if buffering: + #only add this keyword if non-default, for compatibility with + #other response_classes. + kwds["buffering"] = True; + response = self.response_class(*args, **kwds) + + try: + response.begin() + assert response.will_close != _UNKNOWN + self.__state = _CS_IDLE + + if response.will_close: + # this effectively passes the connection to the response + self.close() + else: + # remember this, so we can tell when it is complete + self.__response = response + + return response + except: + response.close() + raise + + +class HTTP: + "Compatibility class with httplib.py from 1.5." + + _http_vsn = 10 + _http_vsn_str = 'HTTP/1.0' + + debuglevel = 0 + + _connection_class = HTTPConnection + + def __init__(self, host='', port=None, strict=None): + "Provide a default host, since the superclass requires one." + + # some joker passed 0 explicitly, meaning default port + if port == 0: + port = None + + # Note that we may pass an empty string as the host; this will raise + # an error when we attempt to connect. Presumably, the client code + # will call connect before then, with a proper host. + self._setup(self._connection_class(host, port, strict)) + + def _setup(self, conn): + self._conn = conn + + # set up delegation to flesh out interface + self.send = conn.send + self.putrequest = conn.putrequest + self.putheader = conn.putheader + self.endheaders = conn.endheaders + self.set_debuglevel = conn.set_debuglevel + + conn._http_vsn = self._http_vsn + conn._http_vsn_str = self._http_vsn_str + + self.file = None + + def connect(self, host=None, port=None): + "Accept arguments to set the host/port, since the superclass doesn't." + + if host is not None: + (self._conn.host, self._conn.port) = self._conn._get_hostport(host, port) + self._conn.connect() + + def getfile(self): + "Provide a getfile, since the superclass' does not use this concept." + return self.file + + def getreply(self, buffering=False): + """Compat definition since superclass does not define it. + + Returns a tuple consisting of: + - server status code (e.g. '200' if all goes well) + - server "reason" corresponding to status code + - any RFC822 headers in the response from the server + """ + try: + if not buffering: + response = self._conn.getresponse() + else: + #only add this keyword if non-default for compatibility + #with other connection classes + response = self._conn.getresponse(buffering) + except BadStatusLine, e: + ### hmm. if getresponse() ever closes the socket on a bad request, + ### then we are going to have problems with self.sock + + ### should we keep this behavior? do people use it? + # keep the socket open (as a file), and return it + self.file = self._conn.sock.makefile('rb', 0) + + # close our socket -- we want to restart after any protocol error + self.close() + + self.headers = None + return -1, e.line, None + + self.headers = response.msg + self.file = response.fp + return response.status, response.reason, response.msg + + def close(self): + self._conn.close() + + # note that self.file == response.fp, which gets closed by the + # superclass. just clear the object ref here. + ### hmm. messy. if status==-1, then self.file is owned by us. + ### well... we aren't explicitly closing, but losing this ref will + ### do it + self.file = None + +try: + import ssl +except ImportError: + pass +else: + class HTTPSConnection(HTTPConnection): + "This class allows communication via SSL." + + default_port = HTTPS_PORT + + def __init__(self, host, port=None, key_file=None, cert_file=None, + strict=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, + source_address=None, context=None): + HTTPConnection.__init__(self, host, port, strict, timeout, + source_address) + self.key_file = key_file + self.cert_file = cert_file + if context is None: + context = ssl._create_default_https_context() + if key_file or cert_file: + context.load_cert_chain(cert_file, key_file) + self._context = context + + def connect(self): + "Connect to a host on a given (SSL) port." + + HTTPConnection.connect(self) + + if self._tunnel_host: + server_hostname = self._tunnel_host + else: + server_hostname = self.host + + self.sock = self._context.wrap_socket(self.sock, + server_hostname=server_hostname) + + __all__.append("HTTPSConnection") + + class HTTPS(HTTP): + """Compatibility with 1.5 httplib interface + + Python 1.5.2 did not have an HTTPS class, but it defined an + interface for sending http requests that is also useful for + https. + """ + + _connection_class = HTTPSConnection + + def __init__(self, host='', port=None, key_file=None, cert_file=None, + strict=None, context=None): + # provide a default host, pass the X509 cert info + + # urf. compensate for bad input. + if port == 0: + port = None + self._setup(self._connection_class(host, port, key_file, + cert_file, strict, + context=context)) + + # we never actually use these for anything, but we keep them + # here for compatibility with post-1.5.2 CVS. + self.key_file = key_file + self.cert_file = cert_file + + + def FakeSocket (sock, sslobj): + warnings.warn("FakeSocket is deprecated, and won't be in 3.x. " + + "Use the result of ssl.wrap_socket() directly instead.", + DeprecationWarning, stacklevel=2) + return sslobj + + +class HTTPException(Exception): + # Subclasses that define an __init__ must call Exception.__init__ + # or define self.args. Otherwise, str() will fail. + pass + +class NotConnected(HTTPException): + pass + +class InvalidURL(HTTPException): + pass + +class UnknownProtocol(HTTPException): + def __init__(self, version): + self.args = version, + self.version = version + +class UnknownTransferEncoding(HTTPException): + pass + +class UnimplementedFileMode(HTTPException): + pass + +class IncompleteRead(HTTPException): + def __init__(self, partial, expected=None): + self.args = partial, + self.partial = partial + self.expected = expected + def __repr__(self): + if self.expected is not None: + e = ', %i more expected' % self.expected + else: + e = '' + return 'IncompleteRead(%i bytes read%s)' % (len(self.partial), e) + def __str__(self): + return repr(self) + +class ImproperConnectionState(HTTPException): + pass + +class CannotSendRequest(ImproperConnectionState): + pass + +class CannotSendHeader(ImproperConnectionState): + pass + +class ResponseNotReady(ImproperConnectionState): + pass + +class BadStatusLine(HTTPException): + def __init__(self, line): + if not line: + line = repr(line) + self.args = line, + self.line = line + +class LineTooLong(HTTPException): + def __init__(self, line_type): + HTTPException.__init__(self, "got more than %d bytes when reading %s" + % (_MAXLINE, line_type)) + +# for backwards compatibility +error = HTTPException + +class LineAndFileWrapper: + """A limited file-like object for HTTP/0.9 responses.""" + + # The status-line parsing code calls readline(), which normally + # get the HTTP status line. For a 0.9 response, however, this is + # actually the first line of the body! Clients need to get a + # readable file object that contains that line. + + def __init__(self, line, file): + self._line = line + self._file = file + self._line_consumed = 0 + self._line_offset = 0 + self._line_left = len(line) + + def __getattr__(self, attr): + return getattr(self._file, attr) + + def _done(self): + # called when the last byte is read from the line. After the + # call, all read methods are delegated to the underlying file + # object. + self._line_consumed = 1 + self.read = self._file.read + self.readline = self._file.readline + self.readlines = self._file.readlines + + def read(self, amt=None): + if self._line_consumed: + return self._file.read(amt) + assert self._line_left + if amt is None or amt > self._line_left: + s = self._line[self._line_offset:] + self._done() + if amt is None: + return s + self._file.read() + else: + return s + self._file.read(amt - len(s)) + else: + assert amt <= self._line_left + i = self._line_offset + j = i + amt + s = self._line[i:j] + self._line_offset = j + self._line_left -= amt + if self._line_left == 0: + self._done() + return s + + def readline(self): + if self._line_consumed: + return self._file.readline() + assert self._line_left + s = self._line[self._line_offset:] + self._done() + return s + + def readlines(self, size=None): + if self._line_consumed: + return self._file.readlines(size) + assert self._line_left + L = [self._line[self._line_offset:]] + self._done() + if size is None: + return L + self._file.readlines() + else: + return L + self._file.readlines(size) diff --git a/io.py b/io.py new file mode 100644 index 0000000..1438493 --- /dev/null +++ b/io.py @@ -0,0 +1,90 @@ +"""The io module provides the Python interfaces to stream handling. The +builtin open function is defined in this module. + +At the top of the I/O hierarchy is the abstract base class IOBase. It +defines the basic interface to a stream. Note, however, that there is no +separation between reading and writing to streams; implementations are +allowed to raise an IOError if they do not support a given operation. + +Extending IOBase is RawIOBase which deals simply with the reading and +writing of raw bytes to a stream. FileIO subclasses RawIOBase to provide +an interface to OS files. + +BufferedIOBase deals with buffering on a raw byte stream (RawIOBase). Its +subclasses, BufferedWriter, BufferedReader, and BufferedRWPair buffer +streams that are readable, writable, and both respectively. +BufferedRandom provides a buffered interface to random access +streams. BytesIO is a simple stream of in-memory bytes. + +Another IOBase subclass, TextIOBase, deals with the encoding and decoding +of streams into text. TextIOWrapper, which extends it, is a buffered text +interface to a buffered raw stream (`BufferedIOBase`). Finally, StringIO +is a in-memory stream for text. + +Argument names are not part of the specification, and only the arguments +of open() are intended to be used as keyword arguments. + +data: + +DEFAULT_BUFFER_SIZE + + An int containing the default buffer size used by the module's buffered + I/O classes. open() uses the file's blksize (as obtained by os.stat) if + possible. +""" +# New I/O library conforming to PEP 3116. + +__author__ = ("Guido van Rossum , " + "Mike Verdone , " + "Mark Russell , " + "Antoine Pitrou , " + "Amaury Forgeot d'Arc , " + "Benjamin Peterson ") + +__all__ = ["BlockingIOError", "open", "IOBase", "RawIOBase", "FileIO", + "BytesIO", "StringIO", "BufferedIOBase", + "BufferedReader", "BufferedWriter", "BufferedRWPair", + "BufferedRandom", "TextIOBase", "TextIOWrapper", + "UnsupportedOperation", "SEEK_SET", "SEEK_CUR", "SEEK_END"] + + +import _io +import abc + +from _io import (DEFAULT_BUFFER_SIZE, BlockingIOError, UnsupportedOperation, + open, FileIO, BytesIO, StringIO, BufferedReader, + BufferedWriter, BufferedRWPair, BufferedRandom, + IncrementalNewlineDecoder, TextIOWrapper) + +OpenWrapper = _io.open # for compatibility with _pyio + +# for seek() +SEEK_SET = 0 +SEEK_CUR = 1 +SEEK_END = 2 + +# Declaring ABCs in C is tricky so we do it here. +# Method descriptions and default implementations are inherited from the C +# version however. +class IOBase(_io._IOBase): + __metaclass__ = abc.ABCMeta + __doc__ = _io._IOBase.__doc__ + +class RawIOBase(_io._RawIOBase, IOBase): + __doc__ = _io._RawIOBase.__doc__ + +class BufferedIOBase(_io._BufferedIOBase, IOBase): + __doc__ = _io._BufferedIOBase.__doc__ + +class TextIOBase(_io._TextIOBase, IOBase): + __doc__ = _io._TextIOBase.__doc__ + +RawIOBase.register(FileIO) + +for klass in (BytesIO, BufferedReader, BufferedWriter, BufferedRandom, + BufferedRWPair): + BufferedIOBase.register(klass) + +for klass in (StringIO, TextIOWrapper): + TextIOBase.register(klass) +del klass diff --git a/ipypulldom.py b/ipypulldom.py new file mode 100644 index 0000000..4aba1b5 --- /dev/null +++ b/ipypulldom.py @@ -0,0 +1,66 @@ +##################################################################################### +# +# Copyright (c) Harry Pierson. All rights reserved. +# +# This source code is subject to terms and conditions of the Microsoft Public License. +# A copy of the license can be found at http://opensource.org/licenses/ms-pl.html +# By using this source code in any fashion, you are agreeing to be bound +# by the terms of the Microsoft Public License. +# +# You must not remove this notice, or any other, from this software. +# +##################################################################################### + +import clr +clr.AddReference('System.Xml') + +from System.String import IsNullOrEmpty +from System.Xml import XmlReader, XmlNodeType, XmlReaderSettings, DtdProcessing +from System.IO import StringReader + +class XmlNode(object): + def __init__(self, xr): + self.name = xr.LocalName + self.namespace = xr.NamespaceURI + self.prefix = xr.Prefix + self.value = xr.Value + self.nodeType = xr.NodeType + + if xr.NodeType == XmlNodeType.Element: + self.attributes = [] + while xr.MoveToNextAttribute(): + if xr.NamespaceURI == 'http://www.w3.org/2000/xmlns/': + continue + self.attributes.append(XmlNode(xr)) + xr.MoveToElement() + + @property + def xname(self): + if IsNullOrEmpty(self.namespace): + return self.name + return "{%(namespace)s}%(name)s" % self.__dict__ + + +def parse(xml): + # see issue 379, and https://stackoverflow.com/questions/215854/ + settings = XmlReaderSettings(); + settings.XmlResolver = None; + settings.DtdProcessing = DtdProcessing.Ignore; + settings.ProhibitDtd = False; + + with XmlReader.Create(xml, settings) as xr: + while xr.Read(): + xr.MoveToContent() + node = XmlNode(xr) + yield node + if xr.IsEmptyElement: + node.nodeType = XmlNodeType.EndElement + del node.attributes + yield node + +def parseString(xml): + return parse(StringReader(xml)) + +if __name__ == "__main__": + nodes = parse('http://feeds.feedburner.com/Devhawk') + \ No newline at end of file diff --git a/keyword.py b/keyword.py new file mode 100644 index 0000000..69794bd --- /dev/null +++ b/keyword.py @@ -0,0 +1,93 @@ +#! /usr/bin/env python + +"""Keywords (from "graminit.c") + +This file is automatically generated; please don't muck it up! + +To update the symbols in this file, 'cd' to the top directory of +the python source tree after building the interpreter and run: + + ./python Lib/keyword.py +""" + +__all__ = ["iskeyword", "kwlist"] + +kwlist = [ +#--start keywords-- + 'and', + 'as', + 'assert', + 'break', + 'class', + 'continue', + 'def', + 'del', + 'elif', + 'else', + 'except', + 'exec', + 'finally', + 'for', + 'from', + 'global', + 'if', + 'import', + 'in', + 'is', + 'lambda', + 'not', + 'or', + 'pass', + 'print', + 'raise', + 'return', + 'try', + 'while', + 'with', + 'yield', +#--end keywords-- + ] + +iskeyword = frozenset(kwlist).__contains__ + +def main(): + import sys, re + + args = sys.argv[1:] + iptfile = args and args[0] or "Python/graminit.c" + if len(args) > 1: optfile = args[1] + else: optfile = "Lib/keyword.py" + + # scan the source file for keywords + fp = open(iptfile) + strprog = re.compile('"([^"]+)"') + lines = [] + for line in fp: + if '{1, "' in line: + match = strprog.search(line) + if match: + lines.append(" '" + match.group(1) + "',\n") + fp.close() + lines.sort() + + # load the output skeleton from the target + fp = open(optfile) + format = fp.readlines() + fp.close() + + # insert the lines of keywords + try: + start = format.index("#--start keywords--\n") + 1 + end = format.index("#--end keywords--\n") + format[start:end] = lines + except ValueError: + sys.stderr.write("target does not contain format markers\n") + sys.exit(1) + + # write the output file + fp = open(optfile, 'w') + fp.write(''.join(format)) + fp.close() + +if __name__ == "__main__": + main() diff --git a/linecache.py b/linecache.py new file mode 100644 index 0000000..4b97be3 --- /dev/null +++ b/linecache.py @@ -0,0 +1,139 @@ +"""Cache lines from files. + +This is intended to read lines from modules imported -- hence if a filename +is not found, it will look down the module search path for a file by +that name. +""" + +import sys +import os + +__all__ = ["getline", "clearcache", "checkcache"] + +def getline(filename, lineno, module_globals=None): + lines = getlines(filename, module_globals) + if 1 <= lineno <= len(lines): + return lines[lineno-1] + else: + return '' + + +# The cache + +cache = {} # The cache + + +def clearcache(): + """Clear the cache entirely.""" + + global cache + cache = {} + + +def getlines(filename, module_globals=None): + """Get the lines for a file from the cache. + Update the cache if it doesn't contain an entry for this file already.""" + + if filename in cache: + return cache[filename][2] + + try: + return updatecache(filename, module_globals) + except MemoryError: + clearcache() + return [] + + +def checkcache(filename=None): + """Discard cache entries that are out of date. + (This is not checked upon each call!)""" + + if filename is None: + filenames = cache.keys() + else: + if filename in cache: + filenames = [filename] + else: + return + + for filename in filenames: + size, mtime, lines, fullname = cache[filename] + if mtime is None: + continue # no-op for files loaded via a __loader__ + try: + stat = os.stat(fullname) + except os.error: + del cache[filename] + continue + if size != stat.st_size or mtime != stat.st_mtime: + del cache[filename] + + +def updatecache(filename, module_globals=None): + """Update a cache entry and return its list of lines. + If something's wrong, print a message, discard the cache entry, + and return an empty list.""" + + if filename in cache: + del cache[filename] + if not filename or (filename.startswith('<') and filename.endswith('>')): + return [] + + fullname = filename + try: + stat = os.stat(fullname) + except OSError: + basename = filename + + # Try for a __loader__, if available + if module_globals and '__loader__' in module_globals: + name = module_globals.get('__name__') + loader = module_globals['__loader__'] + get_source = getattr(loader, 'get_source', None) + + if name and get_source: + try: + data = get_source(name) + except (ImportError, IOError): + pass + else: + if data is None: + # No luck, the PEP302 loader cannot find the source + # for this module. + return [] + cache[filename] = ( + len(data), None, + [line+'\n' for line in data.splitlines()], fullname + ) + return cache[filename][2] + + # Try looking through the module search path, which is only useful + # when handling a relative filename. + if os.path.isabs(filename): + return [] + + for dirname in sys.path: + # When using imputil, sys.path may contain things other than + # strings; ignore them when it happens. + try: + fullname = os.path.join(dirname, basename) + except (TypeError, AttributeError): + # Not sufficiently string-like to do anything useful with. + continue + try: + stat = os.stat(fullname) + break + except os.error: + pass + else: + return [] + try: + with open(fullname, 'rU') as fp: + lines = fp.readlines() + except IOError: + return [] + if lines and not lines[-1].endswith('\n'): + lines[-1] += '\n' + size, mtime = stat.st_size, stat.st_mtime + cache[filename] = size, mtime, lines, fullname + return lines diff --git a/mimetools.py b/mimetools.py new file mode 100644 index 0000000..71ca8f8 --- /dev/null +++ b/mimetools.py @@ -0,0 +1,250 @@ +"""Various tools used by MIME-reading or MIME-writing programs.""" + + +import os +import sys +import tempfile +from warnings import filterwarnings, catch_warnings +with catch_warnings(): + if sys.py3kwarning: + filterwarnings("ignore", ".*rfc822 has been removed", DeprecationWarning) + import rfc822 + +from warnings import warnpy3k +warnpy3k("in 3.x, mimetools has been removed in favor of the email package", + stacklevel=2) + +__all__ = ["Message","choose_boundary","encode","decode","copyliteral", + "copybinary"] + +class Message(rfc822.Message): + """A derived class of rfc822.Message that knows about MIME headers and + contains some hooks for decoding encoded and multipart messages.""" + + def __init__(self, fp, seekable = 1): + rfc822.Message.__init__(self, fp, seekable) + self.encodingheader = \ + self.getheader('content-transfer-encoding') + self.typeheader = \ + self.getheader('content-type') + self.parsetype() + self.parseplist() + + def parsetype(self): + str = self.typeheader + if str is None: + str = 'text/plain' + if ';' in str: + i = str.index(';') + self.plisttext = str[i:] + str = str[:i] + else: + self.plisttext = '' + fields = str.split('/') + for i in range(len(fields)): + fields[i] = fields[i].strip().lower() + self.type = '/'.join(fields) + self.maintype = fields[0] + self.subtype = '/'.join(fields[1:]) + + def parseplist(self): + str = self.plisttext + self.plist = [] + while str[:1] == ';': + str = str[1:] + if ';' in str: + # XXX Should parse quotes! + end = str.index(';') + else: + end = len(str) + f = str[:end] + if '=' in f: + i = f.index('=') + f = f[:i].strip().lower() + \ + '=' + f[i+1:].strip() + self.plist.append(f.strip()) + str = str[end:] + + def getplist(self): + return self.plist + + def getparam(self, name): + name = name.lower() + '=' + n = len(name) + for p in self.plist: + if p[:n] == name: + return rfc822.unquote(p[n:]) + return None + + def getparamnames(self): + result = [] + for p in self.plist: + i = p.find('=') + if i >= 0: + result.append(p[:i].lower()) + return result + + def getencoding(self): + if self.encodingheader is None: + return '7bit' + return self.encodingheader.lower() + + def gettype(self): + return self.type + + def getmaintype(self): + return self.maintype + + def getsubtype(self): + return self.subtype + + + + +# Utility functions +# ----------------- + +try: + import thread +except ImportError: + import dummy_thread as thread +_counter_lock = thread.allocate_lock() +del thread + +_counter = 0 +def _get_next_counter(): + global _counter + _counter_lock.acquire() + _counter += 1 + result = _counter + _counter_lock.release() + return result + +_prefix = None + +def choose_boundary(): + """Return a string usable as a multipart boundary. + + The string chosen is unique within a single program run, and + incorporates the user id (if available), process id (if available), + and current time. So it's very unlikely the returned string appears + in message text, but there's no guarantee. + + The boundary contains dots so you have to quote it in the header.""" + + global _prefix + import time + if _prefix is None: + import socket + try: + hostid = socket.gethostbyname(socket.gethostname()) + except socket.gaierror: + hostid = '127.0.0.1' + try: + uid = repr(os.getuid()) + except AttributeError: + uid = '1' + try: + pid = repr(os.getpid()) + except AttributeError: + pid = '1' + _prefix = hostid + '.' + uid + '.' + pid + return "%s.%.3f.%d" % (_prefix, time.time(), _get_next_counter()) + + +# Subroutines for decoding some common content-transfer-types + +def decode(input, output, encoding): + """Decode common content-transfer-encodings (base64, quopri, uuencode).""" + if encoding == 'base64': + import base64 + return base64.decode(input, output) + if encoding == 'quoted-printable': + import quopri + return quopri.decode(input, output) + if encoding in ('uuencode', 'x-uuencode', 'uue', 'x-uue'): + import uu + return uu.decode(input, output) + if encoding in ('7bit', '8bit'): + return output.write(input.read()) + if encoding in decodetab: + pipethrough(input, decodetab[encoding], output) + else: + raise ValueError, \ + 'unknown Content-Transfer-Encoding: %s' % encoding + +def encode(input, output, encoding): + """Encode common content-transfer-encodings (base64, quopri, uuencode).""" + if encoding == 'base64': + import base64 + return base64.encode(input, output) + if encoding == 'quoted-printable': + import quopri + return quopri.encode(input, output, 0) + if encoding in ('uuencode', 'x-uuencode', 'uue', 'x-uue'): + import uu + return uu.encode(input, output) + if encoding in ('7bit', '8bit'): + return output.write(input.read()) + if encoding in encodetab: + pipethrough(input, encodetab[encoding], output) + else: + raise ValueError, \ + 'unknown Content-Transfer-Encoding: %s' % encoding + +# The following is no longer used for standard encodings + +# XXX This requires that uudecode and mmencode are in $PATH + +uudecode_pipe = '''( +TEMP=/tmp/@uu.$$ +sed "s%^begin [0-7][0-7]* .*%begin 600 $TEMP%" | uudecode +cat $TEMP +rm $TEMP +)''' + +decodetab = { + 'uuencode': uudecode_pipe, + 'x-uuencode': uudecode_pipe, + 'uue': uudecode_pipe, + 'x-uue': uudecode_pipe, + 'quoted-printable': 'mmencode -u -q', + 'base64': 'mmencode -u -b', +} + +encodetab = { + 'x-uuencode': 'uuencode tempfile', + 'uuencode': 'uuencode tempfile', + 'x-uue': 'uuencode tempfile', + 'uue': 'uuencode tempfile', + 'quoted-printable': 'mmencode -q', + 'base64': 'mmencode -b', +} + +def pipeto(input, command): + pipe = os.popen(command, 'w') + copyliteral(input, pipe) + pipe.close() + +def pipethrough(input, command, output): + (fd, tempname) = tempfile.mkstemp() + temp = os.fdopen(fd, 'w') + copyliteral(input, temp) + temp.close() + pipe = os.popen(command + ' <' + tempname, 'r') + copybinary(pipe, output) + pipe.close() + os.unlink(tempname) + +def copyliteral(input, output): + while 1: + line = input.readline() + if not line: break + output.write(line) + +def copybinary(input, output): + BUFSIZE = 8192 + while 1: + line = input.read(BUFSIZE) + if not line: break + output.write(line) diff --git a/ntpath.py b/ntpath.py new file mode 100644 index 0000000..58951b9 --- /dev/null +++ b/ntpath.py @@ -0,0 +1,550 @@ +# Module 'ntpath' -- common operations on WinNT/Win95 pathnames +"""Common pathname manipulations, WindowsNT/95 version. + +Instead of importing this module directly, import os and refer to this +module as os.path. +""" + +import os +import sys +import stat +import genericpath +import warnings + +from genericpath import * +from genericpath import _unicode + +__all__ = ["normcase","isabs","join","splitdrive","split","splitext", + "basename","dirname","commonprefix","getsize","getmtime", + "getatime","getctime", "islink","exists","lexists","isdir","isfile", + "ismount","walk","expanduser","expandvars","normpath","abspath", + "splitunc","curdir","pardir","sep","pathsep","defpath","altsep", + "extsep","devnull","realpath","supports_unicode_filenames","relpath"] + +# strings representing various path-related bits and pieces +curdir = '.' +pardir = '..' +extsep = '.' +sep = '\\' +pathsep = ';' +altsep = '/' +defpath = '.;C:\\bin' +if 'ce' in sys.builtin_module_names: + defpath = '\\Windows' +elif 'os2' in sys.builtin_module_names: + # OS/2 w/ VACPP + altsep = '/' +devnull = 'nul' + +# Normalize the case of a pathname and map slashes to backslashes. +# Other normalizations (such as optimizing '../' away) are not done +# (this is done by normpath). + +def normcase(s): + """Normalize case of pathname. + + Makes all characters lowercase and all slashes into backslashes.""" + return s.replace("/", "\\").lower() + + +# Return whether a path is absolute. +# Trivial in Posix, harder on the Mac or MS-DOS. +# For DOS it is absolute if it starts with a slash or backslash (current +# volume), or if a pathname after the volume letter and colon / UNC resource +# starts with a slash or backslash. + +def isabs(s): + """Test whether a path is absolute""" + s = splitdrive(s)[1] + return s != '' and s[:1] in '/\\' + + +# Join two (or more) paths. +def join(path, *paths): + """Join two or more pathname components, inserting "\\" as needed.""" + result_drive, result_path = splitdrive(path) + for p in paths: + p_drive, p_path = splitdrive(p) + if p_path and p_path[0] in '\\/': + # Second path is absolute + if p_drive or not result_drive: + result_drive = p_drive + result_path = p_path + continue + elif p_drive and p_drive != result_drive: + if p_drive.lower() != result_drive.lower(): + # Different drives => ignore the first path entirely + result_drive = p_drive + result_path = p_path + continue + # Same drive in different case + result_drive = p_drive + # Second path is relative to the first + if result_path and result_path[-1] not in '\\/': + result_path = result_path + '\\' + result_path = result_path + p_path + ## add separator between UNC and non-absolute path + if (result_path and result_path[0] not in '\\/' and + result_drive and result_drive[-1:] != ':'): + return result_drive + sep + result_path + return result_drive + result_path + + +# Split a path in a drive specification (a drive letter followed by a +# colon) and the path specification. +# It is always true that drivespec + pathspec == p +def splitdrive(p): + """Split a pathname into drive/UNC sharepoint and relative path specifiers. + Returns a 2-tuple (drive_or_unc, path); either part may be empty. + + If you assign + result = splitdrive(p) + It is always true that: + result[0] + result[1] == p + + If the path contained a drive letter, drive_or_unc will contain everything + up to and including the colon. e.g. splitdrive("c:/dir") returns ("c:", "/dir") + + If the path contained a UNC path, the drive_or_unc will contain the host name + and share up to but not including the fourth directory separator character. + e.g. splitdrive("//host/computer/dir") returns ("//host/computer", "/dir") + + Paths cannot contain both a drive letter and a UNC path. + + """ + if len(p) > 1: + normp = p.replace(altsep, sep) + if (normp[0:2] == sep*2) and (normp[2:3] != sep): + # is a UNC path: + # vvvvvvvvvvvvvvvvvvvv drive letter or UNC path + # \\machine\mountpoint\directory\etc\... + # directory ^^^^^^^^^^^^^^^ + index = normp.find(sep, 2) + if index == -1: + return '', p + index2 = normp.find(sep, index + 1) + # a UNC path can't have two slashes in a row + # (after the initial two) + if index2 == index + 1: + return '', p + if index2 == -1: + index2 = len(p) + return p[:index2], p[index2:] + if normp[1] == ':': + return p[:2], p[2:] + return '', p + +# Parse UNC paths +def splitunc(p): + """Split a pathname into UNC mount point and relative path specifiers. + + Return a 2-tuple (unc, rest); either part may be empty. + If unc is not empty, it has the form '//host/mount' (or similar + using backslashes). unc+rest is always the input path. + Paths containing drive letters never have an UNC part. + """ + if p[1:2] == ':': + return '', p # Drive letter present + firstTwo = p[0:2] + if firstTwo == '//' or firstTwo == '\\\\': + # is a UNC path: + # vvvvvvvvvvvvvvvvvvvv equivalent to drive letter + # \\machine\mountpoint\directories... + # directory ^^^^^^^^^^^^^^^ + normp = p.replace('\\', '/') + index = normp.find('/', 2) + if index <= 2: + return '', p + index2 = normp.find('/', index + 1) + # a UNC path can't have two slashes in a row + # (after the initial two) + if index2 == index + 1: + return '', p + if index2 == -1: + index2 = len(p) + return p[:index2], p[index2:] + return '', p + + +# Split a path in head (everything up to the last '/') and tail (the +# rest). After the trailing '/' is stripped, the invariant +# join(head, tail) == p holds. +# The resulting head won't end in '/' unless it is the root. + +def split(p): + """Split a pathname. + + Return tuple (head, tail) where tail is everything after the final slash. + Either part may be empty.""" + + d, p = splitdrive(p) + # set i to index beyond p's last slash + i = len(p) + while i and p[i-1] not in '/\\': + i = i - 1 + head, tail = p[:i], p[i:] # now tail has no slashes + # remove trailing slashes from head, unless it's all slashes + head2 = head + while head2 and head2[-1] in '/\\': + head2 = head2[:-1] + head = head2 or head + return d + head, tail + + +# Split a path in root and extension. +# The extension is everything starting at the last dot in the last +# pathname component; the root is everything before that. +# It is always true that root + ext == p. + +def splitext(p): + return genericpath._splitext(p, sep, altsep, extsep) +splitext.__doc__ = genericpath._splitext.__doc__ + + +# Return the tail (basename) part of a path. + +def basename(p): + """Returns the final component of a pathname""" + return split(p)[1] + + +# Return the head (dirname) part of a path. + +def dirname(p): + """Returns the directory component of a pathname""" + return split(p)[0] + +# Is a path a symbolic link? +# This will always return false on systems where posix.lstat doesn't exist. + +def islink(path): + """Test for symbolic link. + On WindowsNT/95 and OS/2 always returns false + """ + return False + +# alias exists to lexists +lexists = exists + +# Is a path a mount point? Either a root (with or without drive letter) +# or an UNC path with at most a / or \ after the mount point. + +def ismount(path): + """Test whether a path is a mount point (defined as root of drive)""" + unc, rest = splitunc(path) + if unc: + return rest in ("", "/", "\\") + p = splitdrive(path)[1] + return len(p) == 1 and p[0] in '/\\' + + +# Directory tree walk. +# For each directory under top (including top itself, but excluding +# '.' and '..'), func(arg, dirname, filenames) is called, where +# dirname is the name of the directory and filenames is the list +# of files (and subdirectories etc.) in the directory. +# The func may modify the filenames list, to implement a filter, +# or to impose a different order of visiting. + +def walk(top, func, arg): + """Directory tree walk with callback function. + + For each directory in the directory tree rooted at top (including top + itself, but excluding '.' and '..'), call func(arg, dirname, fnames). + dirname is the name of the directory, and fnames a list of the names of + the files and subdirectories in dirname (excluding '.' and '..'). func + may modify the fnames list in-place (e.g. via del or slice assignment), + and walk will only recurse into the subdirectories whose names remain in + fnames; this can be used to implement a filter, or to impose a specific + order of visiting. No semantics are defined for, or required of, arg, + beyond that arg is always passed to func. It can be used, e.g., to pass + a filename pattern, or a mutable object designed to accumulate + statistics. Passing None for arg is common.""" + warnings.warnpy3k("In 3.x, os.path.walk is removed in favor of os.walk.", + stacklevel=2) + try: + names = os.listdir(top) + except os.error: + return + func(arg, top, names) + for name in names: + name = join(top, name) + if isdir(name): + walk(name, func, arg) + + +# Expand paths beginning with '~' or '~user'. +# '~' means $HOME; '~user' means that user's home directory. +# If the path doesn't begin with '~', or if the user or $HOME is unknown, +# the path is returned unchanged (leaving error reporting to whatever +# function is called with the expanded path as argument). +# See also module 'glob' for expansion of *, ? and [...] in pathnames. +# (A function should also be defined to do full *sh-style environment +# variable expansion.) + +def expanduser(path): + """Expand ~ and ~user constructs. + + If user or $HOME is unknown, do nothing.""" + if path[:1] != '~': + return path + i, n = 1, len(path) + while i < n and path[i] not in '/\\': + i = i + 1 + + if 'HOME' in os.environ: + userhome = os.environ['HOME'] + elif 'USERPROFILE' in os.environ: + userhome = os.environ['USERPROFILE'] + elif not 'HOMEPATH' in os.environ: + return path + else: + try: + drive = os.environ['HOMEDRIVE'] + except KeyError: + drive = '' + userhome = join(drive, os.environ['HOMEPATH']) + + if i != 1: #~user + userhome = join(dirname(userhome), path[1:i]) + + return userhome + path[i:] + + +# Expand paths containing shell variable substitutions. +# The following rules apply: +# - no expansion within single quotes +# - '$$' is translated into '$' +# - '%%' is translated into '%' if '%%' are not seen in %var1%%var2% +# - ${varname} is accepted. +# - $varname is accepted. +# - %varname% is accepted. +# - varnames can be made out of letters, digits and the characters '_-' +# (though is not verified in the ${varname} and %varname% cases) +# XXX With COMMAND.COM you can use any characters in a variable name, +# XXX except '^|<>='. + +def expandvars(path): + """Expand shell variables of the forms $var, ${var} and %var%. + + Unknown variables are left unchanged.""" + if '$' not in path and '%' not in path: + return path + import string + varchars = string.ascii_letters + string.digits + '_-' + if isinstance(path, _unicode): + encoding = sys.getfilesystemencoding() + def getenv(var): + return os.environ[var.encode(encoding)].decode(encoding) + else: + def getenv(var): + return os.environ[var] + res = '' + index = 0 + pathlen = len(path) + while index < pathlen: + c = path[index] + if c == '\'': # no expansion within single quotes + path = path[index + 1:] + pathlen = len(path) + try: + index = path.index('\'') + res = res + '\'' + path[:index + 1] + except ValueError: + res = res + c + path + index = pathlen - 1 + elif c == '%': # variable or '%' + if path[index + 1:index + 2] == '%': + res = res + c + index = index + 1 + else: + path = path[index+1:] + pathlen = len(path) + try: + index = path.index('%') + except ValueError: + res = res + '%' + path + index = pathlen - 1 + else: + var = path[:index] + try: + res = res + getenv(var) + except KeyError: + res = res + '%' + var + '%' + elif c == '$': # variable or '$$' + if path[index + 1:index + 2] == '$': + res = res + c + index = index + 1 + elif path[index + 1:index + 2] == '{': + path = path[index+2:] + pathlen = len(path) + try: + index = path.index('}') + var = path[:index] + try: + res = res + getenv(var) + except KeyError: + res = res + '${' + var + '}' + except ValueError: + res = res + '${' + path + index = pathlen - 1 + else: + var = '' + index = index + 1 + c = path[index:index + 1] + while c != '' and c in varchars: + var = var + c + index = index + 1 + c = path[index:index + 1] + try: + res = res + getenv(var) + except KeyError: + res = res + '$' + var + if c != '': + index = index - 1 + else: + res = res + c + index = index + 1 + return res + + +# Normalize a path, e.g. A//B, A/./B and A/foo/../B all become A\B. +# Previously, this function also truncated pathnames to 8+3 format, +# but as this module is called "ntpath", that's obviously wrong! + +def normpath(path): + """Normalize path, eliminating double slashes, etc.""" + # Preserve unicode (if path is unicode) + backslash, dot = (u'\\', u'.') if isinstance(path, _unicode) else ('\\', '.') + if path.startswith(('\\\\.\\', '\\\\?\\')): + # in the case of paths with these prefixes: + # \\.\ -> device names + # \\?\ -> literal paths + # do not do any normalization, but return the path unchanged + return path + path = path.replace("/", "\\") + prefix, path = splitdrive(path) + # We need to be careful here. If the prefix is empty, and the path starts + # with a backslash, it could either be an absolute path on the current + # drive (\dir1\dir2\file) or a UNC filename (\\server\mount\dir1\file). It + # is therefore imperative NOT to collapse multiple backslashes blindly in + # that case. + # The code below preserves multiple backslashes when there is no drive + # letter. This means that the invalid filename \\\a\b is preserved + # unchanged, where a\\\b is normalised to a\b. It's not clear that there + # is any better behaviour for such edge cases. + if prefix == '': + # No drive letter - preserve initial backslashes + while path[:1] == "\\": + prefix = prefix + backslash + path = path[1:] + else: + # We have a drive letter - collapse initial backslashes + if path.startswith("\\"): + prefix = prefix + backslash + path = path.lstrip("\\") + comps = path.split("\\") + i = 0 + while i < len(comps): + if comps[i] in ('.', ''): + del comps[i] + elif comps[i] == '..': + if i > 0 and comps[i-1] != '..': + del comps[i-1:i+1] + i -= 1 + elif i == 0 and prefix.endswith("\\"): + del comps[i] + else: + i += 1 + else: + i += 1 + # If the path is now empty, substitute '.' + if not prefix and not comps: + comps.append(dot) + return prefix + backslash.join(comps) + + +# Return an absolute path. +try: + from nt import _getfullpathname + +except ImportError: # not running on Windows - mock up something sensible + def abspath(path): + """Return the absolute version of a path.""" + if not isabs(path): + if isinstance(path, _unicode): + cwd = os.getcwdu() + else: + cwd = os.getcwd() + path = join(cwd, path) + return normpath(path) + +else: # use native Windows method on Windows + def abspath(path): + """Return the absolute version of a path.""" + + if path: # Empty path must return current working directory. + try: + path = _getfullpathname(path) + except WindowsError: + pass # Bad path - return unchanged. + elif isinstance(path, _unicode): + path = os.getcwdu() + else: + path = os.getcwd() + return normpath(path) + +# realpath is a no-op on systems without islink support +realpath = abspath +# Win9x family and earlier have no Unicode filename support. +supports_unicode_filenames = (hasattr(sys, "getwindowsversion") and + sys.getwindowsversion()[3] >= 2) + +def _abspath_split(path): + abs = abspath(normpath(path)) + prefix, rest = splitunc(abs) + is_unc = bool(prefix) + if not is_unc: + prefix, rest = splitdrive(abs) + return is_unc, prefix, [x for x in rest.split(sep) if x] + +def relpath(path, start=curdir): + """Return a relative version of a path""" + + if not path: + raise ValueError("no path specified") + + start_is_unc, start_prefix, start_list = _abspath_split(start) + path_is_unc, path_prefix, path_list = _abspath_split(path) + + if path_is_unc ^ start_is_unc: + raise ValueError("Cannot mix UNC and non-UNC paths (%s and %s)" + % (path, start)) + if path_prefix.lower() != start_prefix.lower(): + if path_is_unc: + raise ValueError("path is on UNC root %s, start on UNC root %s" + % (path_prefix, start_prefix)) + else: + raise ValueError("path is on drive %s, start on drive %s" + % (path_prefix, start_prefix)) + # Work out how much of the filepath is shared by start and path. + i = 0 + for e1, e2 in zip(start_list, path_list): + if e1.lower() != e2.lower(): + break + i += 1 + + rel_list = [pardir] * (len(start_list)-i) + path_list[i:] + if not rel_list: + return curdir + return join(*rel_list) + +try: + # The genericpath.isdir implementation uses os.stat and checks the mode + # attribute to tell whether or not the path is a directory. + # This is overkill on Windows - just pass the path to GetFileAttributes + # and check the attribute from there. + from nt import _isdir as isdir +except ImportError: + # Use genericpath.isdir as imported above. + pass diff --git a/nturl2path.py b/nturl2path.py new file mode 100644 index 0000000..9e6eb0d --- /dev/null +++ b/nturl2path.py @@ -0,0 +1,68 @@ +"""Convert a NT pathname to a file URL and vice versa.""" + +def url2pathname(url): + """OS-specific conversion from a relative URL of the 'file' scheme + to a file system path; not recommended for general use.""" + # e.g. + # ///C|/foo/bar/spam.foo + # and + # ///C:/foo/bar/spam.foo + # become + # C:\foo\bar\spam.foo + import string, urllib + # Windows itself uses ":" even in URLs. + url = url.replace(':', '|') + if not '|' in url: + # No drive specifier, just convert slashes + if url[:4] == '////': + # path is something like ////host/path/on/remote/host + # convert this to \\host\path\on\remote\host + # (notice halving of slashes at the start of the path) + url = url[2:] + components = url.split('/') + # make sure not to convert quoted slashes :-) + return urllib.unquote('\\'.join(components)) + comp = url.split('|') + if len(comp) != 2 or comp[0][-1] not in string.ascii_letters: + error = 'Bad URL: ' + url + raise IOError, error + drive = comp[0][-1].upper() + path = drive + ':' + components = comp[1].split('/') + for comp in components: + if comp: + path = path + '\\' + urllib.unquote(comp) + # Issue #11474: url like '/C|/' should convert into 'C:\\' + if path.endswith(':') and url.endswith('/'): + path += '\\' + return path + +def pathname2url(p): + """OS-specific conversion from a file system path to a relative URL + of the 'file' scheme; not recommended for general use.""" + # e.g. + # C:\foo\bar\spam.foo + # becomes + # ///C:/foo/bar/spam.foo + import urllib + if not ':' in p: + # No drive specifier, just convert slashes and quote the name + if p[:2] == '\\\\': + # path is something like \\host\path\on\remote\host + # convert this to ////host/path/on/remote/host + # (notice doubling of slashes at the start of the path) + p = '\\\\' + p + components = p.split('\\') + return urllib.quote('/'.join(components)) + comp = p.split(':') + if len(comp) != 2 or len(comp[0]) > 1: + error = 'Bad path: ' + p + raise IOError, error + + drive = urllib.quote(comp[0].upper()) + components = comp[1].split('\\') + path = '///' + drive + ':' + for comp in components: + if comp: + path = path + '/' + urllib.quote(comp) + return path diff --git a/os.py b/os.py new file mode 100644 index 0000000..cfea71b --- /dev/null +++ b/os.py @@ -0,0 +1,742 @@ +r"""OS routines for NT or Posix depending on what system we're on. + +This exports: + - all functions from posix, nt, os2, or ce, e.g. unlink, stat, etc. + - os.path is one of the modules posixpath, or ntpath + - os.name is 'posix', 'nt', 'os2', 'ce' or 'riscos' + - os.curdir is a string representing the current directory ('.' or ':') + - os.pardir is a string representing the parent directory ('..' or '::') + - os.sep is the (or a most common) pathname separator ('/' or ':' or '\\') + - os.extsep is the extension separator ('.' or '/') + - os.altsep is the alternate pathname separator (None or '/') + - os.pathsep is the component separator used in $PATH etc + - os.linesep is the line separator in text files ('\r' or '\n' or '\r\n') + - os.defpath is the default search path for executables + - os.devnull is the file path of the null device ('/dev/null', etc.) + +Programs that import and use 'os' stand a better chance of being +portable between different platforms. Of course, they must then +only use functions that are defined by all platforms (e.g., unlink +and opendir), and leave all pathname manipulation to os.path +(e.g., split and join). +""" + +#' + +import sys, errno + +_names = sys.builtin_module_names + +# Note: more names are added to __all__ later. +__all__ = ["altsep", "curdir", "pardir", "sep", "extsep", "pathsep", "linesep", + "defpath", "name", "path", "devnull", + "SEEK_SET", "SEEK_CUR", "SEEK_END"] + +def _get_exports_list(module): + try: + return list(module.__all__) + except AttributeError: + return [n for n in dir(module) if n[0] != '_'] + +if 'posix' in _names: + name = 'posix' + linesep = '\n' + from posix import * + try: + from posix import _exit + except ImportError: + pass + import posixpath as path + + import posix + __all__.extend(_get_exports_list(posix)) + del posix + +elif 'nt' in _names: + name = 'nt' + linesep = '\r\n' + from nt import * + try: + from nt import _exit + except ImportError: + pass + import ntpath as path + + import nt + __all__.extend(_get_exports_list(nt)) + del nt + +elif 'os2' in _names: + name = 'os2' + linesep = '\r\n' + from os2 import * + try: + from os2 import _exit + except ImportError: + pass + if sys.version.find('EMX GCC') == -1: + import ntpath as path + else: + import os2emxpath as path + from _emx_link import link + + import os2 + __all__.extend(_get_exports_list(os2)) + del os2 + +elif 'ce' in _names: + name = 'ce' + linesep = '\r\n' + from ce import * + try: + from ce import _exit + except ImportError: + pass + # We can use the standard Windows path. + import ntpath as path + + import ce + __all__.extend(_get_exports_list(ce)) + del ce + +elif 'riscos' in _names: + name = 'riscos' + linesep = '\n' + from riscos import * + try: + from riscos import _exit + except ImportError: + pass + import riscospath as path + + import riscos + __all__.extend(_get_exports_list(riscos)) + del riscos + +else: + raise ImportError, 'no os specific module found' + +sys.modules['os.path'] = path +from os.path import (curdir, pardir, sep, pathsep, defpath, extsep, altsep, + devnull) + +del _names + +# Python uses fixed values for the SEEK_ constants; they are mapped +# to native constants if necessary in posixmodule.c +SEEK_SET = 0 +SEEK_CUR = 1 +SEEK_END = 2 + +#' + +# Super directory utilities. +# (Inspired by Eric Raymond; the doc strings are mostly his) + +def makedirs(name, mode=0777): + """makedirs(path [, mode=0777]) + + Super-mkdir; create a leaf directory and all intermediate ones. + Works like mkdir, except that any intermediate path segment (not + just the rightmost) will be created if it does not exist. This is + recursive. + + """ + head, tail = path.split(name) + if not tail: + head, tail = path.split(head) + if head and tail and not path.exists(head): + try: + makedirs(head, mode) + except OSError, e: + # be happy if someone already created the path + if e.errno != errno.EEXIST: + raise + if tail == curdir: # xxx/newdir/. exists if xxx/newdir exists + return + mkdir(name, mode) + +def removedirs(name): + """removedirs(path) + + Super-rmdir; remove a leaf directory and all empty intermediate + ones. Works like rmdir except that, if the leaf directory is + successfully removed, directories corresponding to rightmost path + segments will be pruned away until either the whole path is + consumed or an error occurs. Errors during this latter phase are + ignored -- they generally mean that a directory was not empty. + + """ + rmdir(name) + head, tail = path.split(name) + if not tail: + head, tail = path.split(head) + while head and tail: + try: + rmdir(head) + except error: + break + head, tail = path.split(head) + +def renames(old, new): + """renames(old, new) + + Super-rename; create directories as necessary and delete any left + empty. Works like rename, except creation of any intermediate + directories needed to make the new pathname good is attempted + first. After the rename, directories corresponding to rightmost + path segments of the old name will be pruned until either the + whole path is consumed or a nonempty directory is found. + + Note: this function can fail with the new directory structure made + if you lack permissions needed to unlink the leaf directory or + file. + + """ + head, tail = path.split(new) + if head and tail and not path.exists(head): + makedirs(head) + rename(old, new) + head, tail = path.split(old) + if head and tail: + try: + removedirs(head) + except error: + pass + +__all__.extend(["makedirs", "removedirs", "renames"]) + +def walk(top, topdown=True, onerror=None, followlinks=False): + """Directory tree generator. + + For each directory in the directory tree rooted at top (including top + itself, but excluding '.' and '..'), yields a 3-tuple + + dirpath, dirnames, filenames + + dirpath is a string, the path to the directory. dirnames is a list of + the names of the subdirectories in dirpath (excluding '.' and '..'). + filenames is a list of the names of the non-directory files in dirpath. + Note that the names in the lists are just names, with no path components. + To get a full path (which begins with top) to a file or directory in + dirpath, do os.path.join(dirpath, name). + + If optional arg 'topdown' is true or not specified, the triple for a + directory is generated before the triples for any of its subdirectories + (directories are generated top down). If topdown is false, the triple + for a directory is generated after the triples for all of its + subdirectories (directories are generated bottom up). + + When topdown is true, the caller can modify the dirnames list in-place + (e.g., via del or slice assignment), and walk will only recurse into the + subdirectories whose names remain in dirnames; this can be used to prune the + search, or to impose a specific order of visiting. Modifying dirnames when + topdown is false is ineffective, since the directories in dirnames have + already been generated by the time dirnames itself is generated. No matter + the value of topdown, the list of subdirectories is retrieved before the + tuples for the directory and its subdirectories are generated. + + By default errors from the os.listdir() call are ignored. If + optional arg 'onerror' is specified, it should be a function; it + will be called with one argument, an os.error instance. It can + report the error to continue with the walk, or raise the exception + to abort the walk. Note that the filename is available as the + filename attribute of the exception object. + + By default, os.walk does not follow symbolic links to subdirectories on + systems that support them. In order to get this functionality, set the + optional argument 'followlinks' to true. + + Caution: if you pass a relative pathname for top, don't change the + current working directory between resumptions of walk. walk never + changes the current directory, and assumes that the client doesn't + either. + + Example: + + import os + from os.path import join, getsize + for root, dirs, files in os.walk('python/Lib/email'): + print root, "consumes", + print sum([getsize(join(root, name)) for name in files]), + print "bytes in", len(files), "non-directory files" + if 'CVS' in dirs: + dirs.remove('CVS') # don't visit CVS directories + + """ + + islink, join, isdir = path.islink, path.join, path.isdir + + # We may not have read permission for top, in which case we can't + # get a list of the files the directory contains. os.path.walk + # always suppressed the exception then, rather than blow up for a + # minor reason when (say) a thousand readable directories are still + # left to visit. That logic is copied here. + try: + # Note that listdir and error are globals in this module due + # to earlier import-*. + names = listdir(top) + except error, err: + if onerror is not None: + onerror(err) + return + + dirs, nondirs = [], [] + for name in names: + if isdir(join(top, name)): + dirs.append(name) + else: + nondirs.append(name) + + if topdown: + yield top, dirs, nondirs + for name in dirs: + new_path = join(top, name) + if followlinks or not islink(new_path): + for x in walk(new_path, topdown, onerror, followlinks): + yield x + if not topdown: + yield top, dirs, nondirs + +__all__.append("walk") + +# Make sure os.environ exists, at least +try: + environ +except NameError: + environ = {} + +def execl(file, *args): + """execl(file, *args) + + Execute the executable file with argument list args, replacing the + current process. """ + execv(file, args) + +def execle(file, *args): + """execle(file, *args, env) + + Execute the executable file with argument list args and + environment env, replacing the current process. """ + env = args[-1] + execve(file, args[:-1], env) + +def execlp(file, *args): + """execlp(file, *args) + + Execute the executable file (which is searched for along $PATH) + with argument list args, replacing the current process. """ + execvp(file, args) + +def execlpe(file, *args): + """execlpe(file, *args, env) + + Execute the executable file (which is searched for along $PATH) + with argument list args and environment env, replacing the current + process. """ + env = args[-1] + execvpe(file, args[:-1], env) + +def execvp(file, args): + """execvp(file, args) + + Execute the executable file (which is searched for along $PATH) + with argument list args, replacing the current process. + args may be a list or tuple of strings. """ + _execvpe(file, args) + +def execvpe(file, args, env): + """execvpe(file, args, env) + + Execute the executable file (which is searched for along $PATH) + with argument list args and environment env , replacing the + current process. + args may be a list or tuple of strings. """ + _execvpe(file, args, env) + +__all__.extend(["execl","execle","execlp","execlpe","execvp","execvpe"]) + +def _execvpe(file, args, env=None): + if env is not None: + func = execve + argrest = (args, env) + else: + func = execv + argrest = (args,) + env = environ + + head, tail = path.split(file) + if head: + func(file, *argrest) + return + if 'PATH' in env: + envpath = env['PATH'] + else: + envpath = defpath + PATH = envpath.split(pathsep) + saved_exc = None + saved_tb = None + for dir in PATH: + fullname = path.join(dir, file) + try: + func(fullname, *argrest) + except error, e: + tb = sys.exc_info()[2] + if (e.errno != errno.ENOENT and e.errno != errno.ENOTDIR + and saved_exc is None): + saved_exc = e + saved_tb = tb + if saved_exc: + raise error, saved_exc, saved_tb + raise error, e, tb + +# Change environ to automatically call putenv() if it exists +try: + # This will fail if there's no putenv + putenv +except NameError: + pass +else: + import UserDict + + # Fake unsetenv() for Windows + # not sure about os2 here but + # I'm guessing they are the same. + + if name in ('os2', 'nt'): + def unsetenv(key): + putenv(key, "") + + if name == "riscos": + # On RISC OS, all env access goes through getenv and putenv + from riscosenviron import _Environ + elif name in ('os2', 'nt'): # Where Env Var Names Must Be UPPERCASE + # But we store them as upper case + class _Environ(UserDict.IterableUserDict): + def __init__(self, environ): + UserDict.UserDict.__init__(self) + data = self.data + for k, v in environ.items(): + data[k.upper()] = v + def __setitem__(self, key, item): + putenv(key, item) + self.data[key.upper()] = item + def __getitem__(self, key): + return self.data[key.upper()] + try: + unsetenv + except NameError: + def __delitem__(self, key): + del self.data[key.upper()] + else: + def __delitem__(self, key): + unsetenv(key) + del self.data[key.upper()] + def clear(self): + for key in self.data.keys(): + unsetenv(key) + del self.data[key] + def pop(self, key, *args): + unsetenv(key) + return self.data.pop(key.upper(), *args) + def has_key(self, key): + return key.upper() in self.data + def __contains__(self, key): + return key.upper() in self.data + def get(self, key, failobj=None): + return self.data.get(key.upper(), failobj) + def update(self, dict=None, **kwargs): + if dict: + try: + keys = dict.keys() + except AttributeError: + # List of (key, value) + for k, v in dict: + self[k] = v + else: + # got keys + # cannot use items(), since mappings + # may not have them. + for k in keys: + self[k] = dict[k] + if kwargs: + self.update(kwargs) + def copy(self): + return dict(self) + + else: # Where Env Var Names Can Be Mixed Case + class _Environ(UserDict.IterableUserDict): + def __init__(self, environ): + UserDict.UserDict.__init__(self) + self.data = environ + def __setitem__(self, key, item): + putenv(key, item) + self.data[key] = item + def update(self, dict=None, **kwargs): + if dict: + try: + keys = dict.keys() + except AttributeError: + # List of (key, value) + for k, v in dict: + self[k] = v + else: + # got keys + # cannot use items(), since mappings + # may not have them. + for k in keys: + self[k] = dict[k] + if kwargs: + self.update(kwargs) + try: + unsetenv + except NameError: + pass + else: + def __delitem__(self, key): + unsetenv(key) + del self.data[key] + def clear(self): + for key in self.data.keys(): + unsetenv(key) + del self.data[key] + def pop(self, key, *args): + unsetenv(key) + return self.data.pop(key, *args) + def copy(self): + return dict(self) + + + environ = _Environ(environ) + +def getenv(key, default=None): + """Get an environment variable, return None if it doesn't exist. + The optional second argument can specify an alternate default.""" + return environ.get(key, default) +__all__.append("getenv") + +def _exists(name): + return name in globals() + +# Supply spawn*() (probably only for Unix) +if _exists("fork") and not _exists("spawnv") and _exists("execv"): + + P_WAIT = 0 + P_NOWAIT = P_NOWAITO = 1 + + # XXX Should we support P_DETACH? I suppose it could fork()**2 + # and close the std I/O streams. Also, P_OVERLAY is the same + # as execv*()? + + def _spawnvef(mode, file, args, env, func): + # Internal helper; func is the exec*() function to use + pid = fork() + if not pid: + # Child + try: + if env is None: + func(file, args) + else: + func(file, args, env) + except: + _exit(127) + else: + # Parent + if mode == P_NOWAIT: + return pid # Caller is responsible for waiting! + while 1: + wpid, sts = waitpid(pid, 0) + if WIFSTOPPED(sts): + continue + elif WIFSIGNALED(sts): + return -WTERMSIG(sts) + elif WIFEXITED(sts): + return WEXITSTATUS(sts) + else: + raise error, "Not stopped, signaled or exited???" + + def spawnv(mode, file, args): + """spawnv(mode, file, args) -> integer + +Execute file with arguments from args in a subprocess. +If mode == P_NOWAIT return the pid of the process. +If mode == P_WAIT return the process's exit code if it exits normally; +otherwise return -SIG, where SIG is the signal that killed it. """ + return _spawnvef(mode, file, args, None, execv) + + def spawnve(mode, file, args, env): + """spawnve(mode, file, args, env) -> integer + +Execute file with arguments from args in a subprocess with the +specified environment. +If mode == P_NOWAIT return the pid of the process. +If mode == P_WAIT return the process's exit code if it exits normally; +otherwise return -SIG, where SIG is the signal that killed it. """ + return _spawnvef(mode, file, args, env, execve) + + # Note: spawnvp[e] is't currently supported on Windows + + def spawnvp(mode, file, args): + """spawnvp(mode, file, args) -> integer + +Execute file (which is looked for along $PATH) with arguments from +args in a subprocess. +If mode == P_NOWAIT return the pid of the process. +If mode == P_WAIT return the process's exit code if it exits normally; +otherwise return -SIG, where SIG is the signal that killed it. """ + return _spawnvef(mode, file, args, None, execvp) + + def spawnvpe(mode, file, args, env): + """spawnvpe(mode, file, args, env) -> integer + +Execute file (which is looked for along $PATH) with arguments from +args in a subprocess with the supplied environment. +If mode == P_NOWAIT return the pid of the process. +If mode == P_WAIT return the process's exit code if it exits normally; +otherwise return -SIG, where SIG is the signal that killed it. """ + return _spawnvef(mode, file, args, env, execvpe) + +if _exists("spawnv"): + # These aren't supplied by the basic Windows code + # but can be easily implemented in Python + + def spawnl(mode, file, *args): + """spawnl(mode, file, *args) -> integer + +Execute file with arguments from args in a subprocess. +If mode == P_NOWAIT return the pid of the process. +If mode == P_WAIT return the process's exit code if it exits normally; +otherwise return -SIG, where SIG is the signal that killed it. """ + return spawnv(mode, file, args) + + def spawnle(mode, file, *args): + """spawnle(mode, file, *args, env) -> integer + +Execute file with arguments from args in a subprocess with the +supplied environment. +If mode == P_NOWAIT return the pid of the process. +If mode == P_WAIT return the process's exit code if it exits normally; +otherwise return -SIG, where SIG is the signal that killed it. """ + env = args[-1] + return spawnve(mode, file, args[:-1], env) + + + __all__.extend(["spawnv", "spawnve", "spawnl", "spawnle",]) + + +if _exists("spawnvp"): + # At the moment, Windows doesn't implement spawnvp[e], + # so it won't have spawnlp[e] either. + def spawnlp(mode, file, *args): + """spawnlp(mode, file, *args) -> integer + +Execute file (which is looked for along $PATH) with arguments from +args in a subprocess with the supplied environment. +If mode == P_NOWAIT return the pid of the process. +If mode == P_WAIT return the process's exit code if it exits normally; +otherwise return -SIG, where SIG is the signal that killed it. """ + return spawnvp(mode, file, args) + + def spawnlpe(mode, file, *args): + """spawnlpe(mode, file, *args, env) -> integer + +Execute file (which is looked for along $PATH) with arguments from +args in a subprocess with the supplied environment. +If mode == P_NOWAIT return the pid of the process. +If mode == P_WAIT return the process's exit code if it exits normally; +otherwise return -SIG, where SIG is the signal that killed it. """ + env = args[-1] + return spawnvpe(mode, file, args[:-1], env) + + + __all__.extend(["spawnvp", "spawnvpe", "spawnlp", "spawnlpe",]) + + +# Supply popen2 etc. (for Unix) +if _exists("fork"): + if not _exists("popen2"): + def popen2(cmd, mode="t", bufsize=-1): + """Execute the shell command 'cmd' in a sub-process. On UNIX, 'cmd' + may be a sequence, in which case arguments will be passed directly to + the program without shell intervention (as with os.spawnv()). If 'cmd' + is a string it will be passed to the shell (as with os.system()). If + 'bufsize' is specified, it sets the buffer size for the I/O pipes. The + file objects (child_stdin, child_stdout) are returned.""" + import warnings + msg = "os.popen2 is deprecated. Use the subprocess module." + warnings.warn(msg, DeprecationWarning, stacklevel=2) + + import subprocess + PIPE = subprocess.PIPE + p = subprocess.Popen(cmd, shell=isinstance(cmd, basestring), + bufsize=bufsize, stdin=PIPE, stdout=PIPE, + close_fds=True) + return p.stdin, p.stdout + __all__.append("popen2") + + if not _exists("popen3"): + def popen3(cmd, mode="t", bufsize=-1): + """Execute the shell command 'cmd' in a sub-process. On UNIX, 'cmd' + may be a sequence, in which case arguments will be passed directly to + the program without shell intervention (as with os.spawnv()). If 'cmd' + is a string it will be passed to the shell (as with os.system()). If + 'bufsize' is specified, it sets the buffer size for the I/O pipes. The + file objects (child_stdin, child_stdout, child_stderr) are returned.""" + import warnings + msg = "os.popen3 is deprecated. Use the subprocess module." + warnings.warn(msg, DeprecationWarning, stacklevel=2) + + import subprocess + PIPE = subprocess.PIPE + p = subprocess.Popen(cmd, shell=isinstance(cmd, basestring), + bufsize=bufsize, stdin=PIPE, stdout=PIPE, + stderr=PIPE, close_fds=True) + return p.stdin, p.stdout, p.stderr + __all__.append("popen3") + + if not _exists("popen4"): + def popen4(cmd, mode="t", bufsize=-1): + """Execute the shell command 'cmd' in a sub-process. On UNIX, 'cmd' + may be a sequence, in which case arguments will be passed directly to + the program without shell intervention (as with os.spawnv()). If 'cmd' + is a string it will be passed to the shell (as with os.system()). If + 'bufsize' is specified, it sets the buffer size for the I/O pipes. The + file objects (child_stdin, child_stdout_stderr) are returned.""" + import warnings + msg = "os.popen4 is deprecated. Use the subprocess module." + warnings.warn(msg, DeprecationWarning, stacklevel=2) + + import subprocess + PIPE = subprocess.PIPE + p = subprocess.Popen(cmd, shell=isinstance(cmd, basestring), + bufsize=bufsize, stdin=PIPE, stdout=PIPE, + stderr=subprocess.STDOUT, close_fds=True) + return p.stdin, p.stdout + __all__.append("popen4") + +import copy_reg as _copy_reg + +def _make_stat_result(tup, dict): + return stat_result(tup, dict) + +def _pickle_stat_result(sr): + (type, args) = sr.__reduce__() + return (_make_stat_result, args) + +try: + _copy_reg.pickle(stat_result, _pickle_stat_result, _make_stat_result) +except NameError: # stat_result may not exist + pass + +def _make_statvfs_result(tup, dict): + return statvfs_result(tup, dict) + +def _pickle_statvfs_result(sr): + (type, args) = sr.__reduce__() + return (_make_statvfs_result, args) + +try: + _copy_reg.pickle(statvfs_result, _pickle_statvfs_result, + _make_statvfs_result) +except NameError: # statvfs_result may not exist + pass diff --git a/posixpath.py b/posixpath.py new file mode 100644 index 0000000..6578481 --- /dev/null +++ b/posixpath.py @@ -0,0 +1,439 @@ +"""Common operations on Posix pathnames. + +Instead of importing this module directly, import os and refer to +this module as os.path. The "os.path" name is an alias for this +module on Posix systems; on other systems (e.g. Mac, Windows), +os.path provides the same operations in a manner specific to that +platform, and is an alias to another module (e.g. macpath, ntpath). + +Some of this can actually be useful on non-Posix systems too, e.g. +for manipulation of the pathname component of URLs. +""" + +import os +import sys +import stat +import genericpath +import warnings +from genericpath import * +from genericpath import _unicode + +__all__ = ["normcase","isabs","join","splitdrive","split","splitext", + "basename","dirname","commonprefix","getsize","getmtime", + "getatime","getctime","islink","exists","lexists","isdir","isfile", + "ismount","walk","expanduser","expandvars","normpath","abspath", + "samefile","sameopenfile","samestat", + "curdir","pardir","sep","pathsep","defpath","altsep","extsep", + "devnull","realpath","supports_unicode_filenames","relpath"] + +# strings representing various path-related bits and pieces +curdir = '.' +pardir = '..' +extsep = '.' +sep = '/' +pathsep = ':' +defpath = ':/bin:/usr/bin' +altsep = None +devnull = '/dev/null' + +# Normalize the case of a pathname. Trivial in Posix, string.lower on Mac. +# On MS-DOS this may also turn slashes into backslashes; however, other +# normalizations (such as optimizing '../' away) are not allowed +# (another function should be defined to do that). + +def normcase(s): + """Normalize case of pathname. Has no effect under Posix""" + return s + + +# Return whether a path is absolute. +# Trivial in Posix, harder on the Mac or MS-DOS. + +def isabs(s): + """Test whether a path is absolute""" + return s.startswith('/') + + +# Join pathnames. +# Ignore the previous parts if a part is absolute. +# Insert a '/' unless the first part is empty or already ends in '/'. + +def join(a, *p): + """Join two or more pathname components, inserting '/' as needed. + If any component is an absolute path, all previous path components + will be discarded. An empty last part will result in a path that + ends with a separator.""" + path = a + for b in p: + if b.startswith('/'): + path = b + elif path == '' or path.endswith('/'): + path += b + else: + path += '/' + b + return path + + +# Split a path in head (everything up to the last '/') and tail (the +# rest). If the path ends in '/', tail will be empty. If there is no +# '/' in the path, head will be empty. +# Trailing '/'es are stripped from head unless it is the root. + +def split(p): + """Split a pathname. Returns tuple "(head, tail)" where "tail" is + everything after the final slash. Either part may be empty.""" + i = p.rfind('/') + 1 + head, tail = p[:i], p[i:] + if head and head != '/'*len(head): + head = head.rstrip('/') + return head, tail + + +# Split a path in root and extension. +# The extension is everything starting at the last dot in the last +# pathname component; the root is everything before that. +# It is always true that root + ext == p. + +def splitext(p): + return genericpath._splitext(p, sep, altsep, extsep) +splitext.__doc__ = genericpath._splitext.__doc__ + +# Split a pathname into a drive specification and the rest of the +# path. Useful on DOS/Windows/NT; on Unix, the drive is always empty. + +def splitdrive(p): + """Split a pathname into drive and path. On Posix, drive is always + empty.""" + return '', p + + +# Return the tail (basename) part of a path, same as split(path)[1]. + +def basename(p): + """Returns the final component of a pathname""" + i = p.rfind('/') + 1 + return p[i:] + + +# Return the head (dirname) part of a path, same as split(path)[0]. + +def dirname(p): + """Returns the directory component of a pathname""" + i = p.rfind('/') + 1 + head = p[:i] + if head and head != '/'*len(head): + head = head.rstrip('/') + return head + + +# Is a path a symbolic link? +# This will always return false on systems where os.lstat doesn't exist. + +def islink(path): + """Test whether a path is a symbolic link""" + try: + st = os.lstat(path) + except (os.error, AttributeError): + return False + return stat.S_ISLNK(st.st_mode) + +# Being true for dangling symbolic links is also useful. + +def lexists(path): + """Test whether a path exists. Returns True for broken symbolic links""" + try: + os.lstat(path) + except os.error: + return False + return True + + +# Are two filenames really pointing to the same file? + +def samefile(f1, f2): + """Test whether two pathnames reference the same actual file""" + s1 = os.stat(f1) + s2 = os.stat(f2) + return samestat(s1, s2) + + +# Are two open files really referencing the same file? +# (Not necessarily the same file descriptor!) + +def sameopenfile(fp1, fp2): + """Test whether two open file objects reference the same file""" + s1 = os.fstat(fp1) + s2 = os.fstat(fp2) + return samestat(s1, s2) + + +# Are two stat buffers (obtained from stat, fstat or lstat) +# describing the same file? + +def samestat(s1, s2): + """Test whether two stat buffers reference the same file""" + return s1.st_ino == s2.st_ino and \ + s1.st_dev == s2.st_dev + + +# Is a path a mount point? +# (Does this work for all UNIXes? Is it even guaranteed to work by Posix?) + +def ismount(path): + """Test whether a path is a mount point""" + if islink(path): + # A symlink can never be a mount point + return False + try: + s1 = os.lstat(path) + s2 = os.lstat(join(path, '..')) + except os.error: + return False # It doesn't exist -- so not a mount point :-) + dev1 = s1.st_dev + dev2 = s2.st_dev + if dev1 != dev2: + return True # path/.. on a different device as path + ino1 = s1.st_ino + ino2 = s2.st_ino + if ino1 == ino2: + return True # path/.. is the same i-node as path + return False + + +# Directory tree walk. +# For each directory under top (including top itself, but excluding +# '.' and '..'), func(arg, dirname, filenames) is called, where +# dirname is the name of the directory and filenames is the list +# of files (and subdirectories etc.) in the directory. +# The func may modify the filenames list, to implement a filter, +# or to impose a different order of visiting. + +def walk(top, func, arg): + """Directory tree walk with callback function. + + For each directory in the directory tree rooted at top (including top + itself, but excluding '.' and '..'), call func(arg, dirname, fnames). + dirname is the name of the directory, and fnames a list of the names of + the files and subdirectories in dirname (excluding '.' and '..'). func + may modify the fnames list in-place (e.g. via del or slice assignment), + and walk will only recurse into the subdirectories whose names remain in + fnames; this can be used to implement a filter, or to impose a specific + order of visiting. No semantics are defined for, or required of, arg, + beyond that arg is always passed to func. It can be used, e.g., to pass + a filename pattern, or a mutable object designed to accumulate + statistics. Passing None for arg is common.""" + warnings.warnpy3k("In 3.x, os.path.walk is removed in favor of os.walk.", + stacklevel=2) + try: + names = os.listdir(top) + except os.error: + return + func(arg, top, names) + for name in names: + name = join(top, name) + try: + st = os.lstat(name) + except os.error: + continue + if stat.S_ISDIR(st.st_mode): + walk(name, func, arg) + + +# Expand paths beginning with '~' or '~user'. +# '~' means $HOME; '~user' means that user's home directory. +# If the path doesn't begin with '~', or if the user or $HOME is unknown, +# the path is returned unchanged (leaving error reporting to whatever +# function is called with the expanded path as argument). +# See also module 'glob' for expansion of *, ? and [...] in pathnames. +# (A function should also be defined to do full *sh-style environment +# variable expansion.) + +def expanduser(path): + """Expand ~ and ~user constructions. If user or $HOME is unknown, + do nothing.""" + if not path.startswith('~'): + return path + i = path.find('/', 1) + if i < 0: + i = len(path) + if i == 1: + if 'HOME' not in os.environ: + import pwd + userhome = pwd.getpwuid(os.getuid()).pw_dir + else: + userhome = os.environ['HOME'] + else: + import pwd + try: + pwent = pwd.getpwnam(path[1:i]) + except KeyError: + return path + userhome = pwent.pw_dir + userhome = userhome.rstrip('/') + return (userhome + path[i:]) or '/' + + +# Expand paths containing shell variable substitutions. +# This expands the forms $variable and ${variable} only. +# Non-existent variables are left unchanged. + +_varprog = None +_uvarprog = None + +def expandvars(path): + """Expand shell variables of form $var and ${var}. Unknown variables + are left unchanged.""" + global _varprog, _uvarprog + if '$' not in path: + return path + if isinstance(path, _unicode): + if not _uvarprog: + import re + _uvarprog = re.compile(ur'\$(\w+|\{[^}]*\})', re.UNICODE) + varprog = _uvarprog + encoding = sys.getfilesystemencoding() + else: + if not _varprog: + import re + _varprog = re.compile(r'\$(\w+|\{[^}]*\})') + varprog = _varprog + encoding = None + i = 0 + while True: + m = varprog.search(path, i) + if not m: + break + i, j = m.span(0) + name = m.group(1) + if name.startswith('{') and name.endswith('}'): + name = name[1:-1] + if encoding: + name = name.encode(encoding) + if name in os.environ: + tail = path[j:] + value = os.environ[name] + if encoding: + value = value.decode(encoding) + path = path[:i] + value + i = len(path) + path += tail + else: + i = j + return path + + +# Normalize a path, e.g. A//B, A/./B and A/foo/../B all become A/B. +# It should be understood that this may change the meaning of the path +# if it contains symbolic links! + +def normpath(path): + """Normalize path, eliminating double slashes, etc.""" + # Preserve unicode (if path is unicode) + slash, dot = (u'/', u'.') if isinstance(path, _unicode) else ('/', '.') + if path == '': + return dot + initial_slashes = path.startswith('/') + # POSIX allows one or two initial slashes, but treats three or more + # as single slash. + if (initial_slashes and + path.startswith('//') and not path.startswith('///')): + initial_slashes = 2 + comps = path.split('/') + new_comps = [] + for comp in comps: + if comp in ('', '.'): + continue + if (comp != '..' or (not initial_slashes and not new_comps) or + (new_comps and new_comps[-1] == '..')): + new_comps.append(comp) + elif new_comps: + new_comps.pop() + comps = new_comps + path = slash.join(comps) + if initial_slashes: + path = slash*initial_slashes + path + return path or dot + + +def abspath(path): + """Return an absolute path.""" + if not isabs(path): + if isinstance(path, _unicode): + cwd = os.getcwdu() + else: + cwd = os.getcwd() + path = join(cwd, path) + return normpath(path) + + +# Return a canonical path (i.e. the absolute location of a file on the +# filesystem). + +def realpath(filename): + """Return the canonical path of the specified filename, eliminating any +symbolic links encountered in the path.""" + path, ok = _joinrealpath('', filename, {}) + return abspath(path) + +# Join two paths, normalizing ang eliminating any symbolic links +# encountered in the second path. +def _joinrealpath(path, rest, seen): + if isabs(rest): + rest = rest[1:] + path = sep + + while rest: + name, _, rest = rest.partition(sep) + if not name or name == curdir: + # current dir + continue + if name == pardir: + # parent dir + if path: + path, name = split(path) + if name == pardir: + path = join(path, pardir, pardir) + else: + path = pardir + continue + newpath = join(path, name) + if not islink(newpath): + path = newpath + continue + # Resolve the symbolic link + if newpath in seen: + # Already seen this path + path = seen[newpath] + if path is not None: + # use cached value + continue + # The symlink is not resolved, so we must have a symlink loop. + # Return already resolved part + rest of the path unchanged. + return join(newpath, rest), False + seen[newpath] = None # not resolved symlink + path, ok = _joinrealpath(path, os.readlink(newpath), seen) + if not ok: + return join(path, rest), False + seen[newpath] = path # resolved symlink + + return path, True + + +supports_unicode_filenames = (sys.platform == 'darwin') + +def relpath(path, start=curdir): + """Return a relative version of a path""" + + if not path: + raise ValueError("no path specified") + + start_list = [x for x in abspath(start).split(sep) if x] + path_list = [x for x in abspath(path).split(sep) if x] + + # Work out how much of the filepath is shared by start and path. + i = len(commonprefix([start_list, path_list])) + + rel_list = [pardir] * (len(start_list)-i) + path_list[i:] + if not rel_list: + return curdir + return join(*rel_list) diff --git a/random.py b/random.py new file mode 100644 index 0000000..3f96a37 --- /dev/null +++ b/random.py @@ -0,0 +1,910 @@ +"""Random variable generators. + + integers + -------- + uniform within range + + sequences + --------- + pick random element + pick random sample + generate random permutation + + distributions on the real line: + ------------------------------ + uniform + triangular + normal (Gaussian) + lognormal + negative exponential + gamma + beta + pareto + Weibull + + distributions on the circle (angles 0 to 2pi) + --------------------------------------------- + circular uniform + von Mises + +General notes on the underlying Mersenne Twister core generator: + +* The period is 2**19937-1. +* It is one of the most extensively tested generators in existence. +* Without a direct way to compute N steps forward, the semantics of + jumpahead(n) are weakened to simply jump to another distant state and rely + on the large period to avoid overlapping sequences. +* The random() method is implemented in C, executes in a single Python step, + and is, therefore, threadsafe. + +""" + +from __future__ import division +from warnings import warn as _warn +from types import MethodType as _MethodType, BuiltinMethodType as _BuiltinMethodType +from math import log as _log, exp as _exp, pi as _pi, e as _e, ceil as _ceil +from math import sqrt as _sqrt, acos as _acos, cos as _cos, sin as _sin +from os import urandom as _urandom +from binascii import hexlify as _hexlify +import hashlib as _hashlib + +__all__ = ["Random","seed","random","uniform","randint","choice","sample", + "randrange","shuffle","normalvariate","lognormvariate", + "expovariate","vonmisesvariate","gammavariate","triangular", + "gauss","betavariate","paretovariate","weibullvariate", + "getstate","setstate","jumpahead", "WichmannHill", "getrandbits", + "SystemRandom"] + +NV_MAGICCONST = 4 * _exp(-0.5)/_sqrt(2.0) +TWOPI = 2.0*_pi +LOG4 = _log(4.0) +SG_MAGICCONST = 1.0 + _log(4.5) +BPF = 53 # Number of bits in a float +RECIP_BPF = 2**-BPF + + +# Translated by Guido van Rossum from C source provided by +# Adrian Baddeley. Adapted by Raymond Hettinger for use with +# the Mersenne Twister and os.urandom() core generators. + +import _random + +class Random(_random.Random): + """Random number generator base class used by bound module functions. + + Used to instantiate instances of Random to get generators that don't + share state. Especially useful for multi-threaded programs, creating + a different instance of Random for each thread, and using the jumpahead() + method to ensure that the generated sequences seen by each thread don't + overlap. + + Class Random can also be subclassed if you want to use a different basic + generator of your own devising: in that case, override the following + methods: random(), seed(), getstate(), setstate() and jumpahead(). + Optionally, implement a getrandbits() method so that randrange() can cover + arbitrarily large ranges. + + """ + + VERSION = 3 # used by getstate/setstate + + def __init__(self, x=None): + """Initialize an instance. + + Optional argument x controls seeding, as for Random.seed(). + """ + + self.seed(x) + self.gauss_next = None + + def seed(self, a=None): + """Initialize internal state from hashable object. + + None or no argument seeds from current time or from an operating + system specific randomness source if available. + + If a is not None or an int or long, hash(a) is used instead. + """ + + if a is None: + try: + # Seed with enough bytes to span the 19937 bit + # state space for the Mersenne Twister + a = long(_hexlify(_urandom(2500)), 16) + except NotImplementedError: + import time + a = long(time.time() * 256) # use fractional seconds + + super(Random, self).seed(a) + self.gauss_next = None + + def getstate(self): + """Return internal state; can be passed to setstate() later.""" + return self.VERSION, super(Random, self).getstate(), self.gauss_next + + def setstate(self, state): + """Restore internal state from object returned by getstate().""" + version = state[0] + if version == 3: + version, internalstate, self.gauss_next = state + super(Random, self).setstate(internalstate) + elif version == 2: + version, internalstate, self.gauss_next = state + # In version 2, the state was saved as signed ints, which causes + # inconsistencies between 32/64-bit systems. The state is + # really unsigned 32-bit ints, so we convert negative ints from + # version 2 to positive longs for version 3. + try: + internalstate = tuple( long(x) % (2**32) for x in internalstate ) + except ValueError, e: + raise TypeError, e + super(Random, self).setstate(internalstate) + else: + raise ValueError("state with version %s passed to " + "Random.setstate() of version %s" % + (version, self.VERSION)) + + def jumpahead(self, n): + """Change the internal state to one that is likely far away + from the current state. This method will not be in Py3.x, + so it is better to simply reseed. + """ + # The super.jumpahead() method uses shuffling to change state, + # so it needs a large and "interesting" n to work with. Here, + # we use hashing to create a large n for the shuffle. + s = repr(n) + repr(self.getstate()) + n = int(_hashlib.new('sha512', s).hexdigest(), 16) + super(Random, self).jumpahead(n) + +## ---- Methods below this point do not need to be overridden when +## ---- subclassing for the purpose of using a different core generator. + +## -------------------- pickle support ------------------- + + def __getstate__(self): # for pickle + return self.getstate() + + def __setstate__(self, state): # for pickle + self.setstate(state) + + def __reduce__(self): + return self.__class__, (), self.getstate() + +## -------------------- integer methods ------------------- + + def randrange(self, start, stop=None, step=1, _int=int, _maxwidth=1L< 0: + if istart >= _maxwidth: + return self._randbelow(istart) + return _int(self.random() * istart) + raise ValueError, "empty range for randrange()" + + # stop argument supplied. + istop = _int(stop) + if istop != stop: + raise ValueError, "non-integer stop for randrange()" + width = istop - istart + if step == 1 and width > 0: + # Note that + # int(istart + self.random()*width) + # instead would be incorrect. For example, consider istart + # = -2 and istop = 0. Then the guts would be in + # -2.0 to 0.0 exclusive on both ends (ignoring that random() + # might return 0.0), and because int() truncates toward 0, the + # final result would be -1 or 0 (instead of -2 or -1). + # istart + int(self.random()*width) + # would also be incorrect, for a subtler reason: the RHS + # can return a long, and then randrange() would also return + # a long, but we're supposed to return an int (for backward + # compatibility). + + if width >= _maxwidth: + return _int(istart + self._randbelow(width)) + return _int(istart + _int(self.random()*width)) + if step == 1: + raise ValueError, "empty range for randrange() (%d,%d, %d)" % (istart, istop, width) + + # Non-unit step argument supplied. + istep = _int(step) + if istep != step: + raise ValueError, "non-integer step for randrange()" + if istep > 0: + n = (width + istep - 1) // istep + elif istep < 0: + n = (width + istep + 1) // istep + else: + raise ValueError, "zero step for randrange()" + + if n <= 0: + raise ValueError, "empty range for randrange()" + + if n >= _maxwidth: + return istart + istep*self._randbelow(n) + return istart + istep*_int(self.random() * n) + + def randint(self, a, b): + """Return random integer in range [a, b], including both end points. + """ + + return self.randrange(a, b+1) + + def _randbelow(self, n, _log=_log, _int=int, _maxwidth=1L< n-1 > 2**(k-2) + r = getrandbits(k) + while r >= n: + r = getrandbits(k) + return r + if n >= _maxwidth: + _warn("Underlying random() generator does not supply \n" + "enough bits to choose from a population range this large") + return _int(self.random() * n) + +## -------------------- sequence methods ------------------- + + def choice(self, seq): + """Choose a random element from a non-empty sequence.""" + return seq[int(self.random() * len(seq))] # raises IndexError if seq is empty + + def shuffle(self, x, random=None): + """x, random=random.random -> shuffle list x in place; return None. + + Optional arg random is a 0-argument function returning a random + float in [0.0, 1.0); by default, the standard random.random. + + """ + + if random is None: + random = self.random + _int = int + for i in reversed(xrange(1, len(x))): + # pick an element in x[:i+1] with which to exchange x[i] + j = _int(random() * (i+1)) + x[i], x[j] = x[j], x[i] + + def sample(self, population, k): + """Chooses k unique random elements from a population sequence. + + Returns a new list containing elements from the population while + leaving the original population unchanged. The resulting list is + in selection order so that all sub-slices will also be valid random + samples. This allows raffle winners (the sample) to be partitioned + into grand prize and second place winners (the subslices). + + Members of the population need not be hashable or unique. If the + population contains repeats, then each occurrence is a possible + selection in the sample. + + To choose a sample in a range of integers, use xrange as an argument. + This is especially fast and space efficient for sampling from a + large population: sample(xrange(10000000), 60) + """ + + # Sampling without replacement entails tracking either potential + # selections (the pool) in a list or previous selections in a set. + + # When the number of selections is small compared to the + # population, then tracking selections is efficient, requiring + # only a small set and an occasional reselection. For + # a larger number of selections, the pool tracking method is + # preferred since the list takes less space than the + # set and it doesn't suffer from frequent reselections. + + n = len(population) + if not 0 <= k <= n: + raise ValueError("sample larger than population") + random = self.random + _int = int + result = [None] * k + setsize = 21 # size of a small set minus size of an empty list + if k > 5: + setsize += 4 ** _ceil(_log(k * 3, 4)) # table size for big sets + if n <= setsize or hasattr(population, "keys"): + # An n-length list is smaller than a k-length set, or this is a + # mapping type so the other algorithm wouldn't work. + pool = list(population) + for i in xrange(k): # invariant: non-selected at [0,n-i) + j = _int(random() * (n-i)) + result[i] = pool[j] + pool[j] = pool[n-i-1] # move non-selected item into vacancy + else: + try: + selected = set() + selected_add = selected.add + for i in xrange(k): + j = _int(random() * n) + while j in selected: + j = _int(random() * n) + selected_add(j) + result[i] = population[j] + except (TypeError, KeyError): # handle (at least) sets + if isinstance(population, list): + raise + return self.sample(tuple(population), k) + return result + +## -------------------- real-valued distributions ------------------- + +## -------------------- uniform distribution ------------------- + + def uniform(self, a, b): + "Get a random number in the range [a, b) or [a, b] depending on rounding." + return a + (b-a) * self.random() + +## -------------------- triangular -------------------- + + def triangular(self, low=0.0, high=1.0, mode=None): + """Triangular distribution. + + Continuous distribution bounded by given lower and upper limits, + and having a given mode value in-between. + + http://en.wikipedia.org/wiki/Triangular_distribution + + """ + u = self.random() + try: + c = 0.5 if mode is None else (mode - low) / (high - low) + except ZeroDivisionError: + return low + if u > c: + u = 1.0 - u + c = 1.0 - c + low, high = high, low + return low + (high - low) * (u * c) ** 0.5 + +## -------------------- normal distribution -------------------- + + def normalvariate(self, mu, sigma): + """Normal distribution. + + mu is the mean, and sigma is the standard deviation. + + """ + # mu = mean, sigma = standard deviation + + # Uses Kinderman and Monahan method. Reference: Kinderman, + # A.J. and Monahan, J.F., "Computer generation of random + # variables using the ratio of uniform deviates", ACM Trans + # Math Software, 3, (1977), pp257-260. + + random = self.random + while 1: + u1 = random() + u2 = 1.0 - random() + z = NV_MAGICCONST*(u1-0.5)/u2 + zz = z*z/4.0 + if zz <= -_log(u2): + break + return mu + z*sigma + +## -------------------- lognormal distribution -------------------- + + def lognormvariate(self, mu, sigma): + """Log normal distribution. + + If you take the natural logarithm of this distribution, you'll get a + normal distribution with mean mu and standard deviation sigma. + mu can have any value, and sigma must be greater than zero. + + """ + return _exp(self.normalvariate(mu, sigma)) + +## -------------------- exponential distribution -------------------- + + def expovariate(self, lambd): + """Exponential distribution. + + lambd is 1.0 divided by the desired mean. It should be + nonzero. (The parameter would be called "lambda", but that is + a reserved word in Python.) Returned values range from 0 to + positive infinity if lambd is positive, and from negative + infinity to 0 if lambd is negative. + + """ + # lambd: rate lambd = 1/mean + # ('lambda' is a Python reserved word) + + # we use 1-random() instead of random() to preclude the + # possibility of taking the log of zero. + return -_log(1.0 - self.random())/lambd + +## -------------------- von Mises distribution -------------------- + + def vonmisesvariate(self, mu, kappa): + """Circular data distribution. + + mu is the mean angle, expressed in radians between 0 and 2*pi, and + kappa is the concentration parameter, which must be greater than or + equal to zero. If kappa is equal to zero, this distribution reduces + to a uniform random angle over the range 0 to 2*pi. + + """ + # mu: mean angle (in radians between 0 and 2*pi) + # kappa: concentration parameter kappa (>= 0) + # if kappa = 0 generate uniform random angle + + # Based upon an algorithm published in: Fisher, N.I., + # "Statistical Analysis of Circular Data", Cambridge + # University Press, 1993. + + # Thanks to Magnus Kessler for a correction to the + # implementation of step 4. + + random = self.random + if kappa <= 1e-6: + return TWOPI * random() + + s = 0.5 / kappa + r = s + _sqrt(1.0 + s * s) + + while 1: + u1 = random() + z = _cos(_pi * u1) + + d = z / (r + z) + u2 = random() + if u2 < 1.0 - d * d or u2 <= (1.0 - d) * _exp(d): + break + + q = 1.0 / r + f = (q + z) / (1.0 + q * z) + u3 = random() + if u3 > 0.5: + theta = (mu + _acos(f)) % TWOPI + else: + theta = (mu - _acos(f)) % TWOPI + + return theta + +## -------------------- gamma distribution -------------------- + + def gammavariate(self, alpha, beta): + """Gamma distribution. Not the gamma function! + + Conditions on the parameters are alpha > 0 and beta > 0. + + The probability distribution function is: + + x ** (alpha - 1) * math.exp(-x / beta) + pdf(x) = -------------------------------------- + math.gamma(alpha) * beta ** alpha + + """ + + # alpha > 0, beta > 0, mean is alpha*beta, variance is alpha*beta**2 + + # Warning: a few older sources define the gamma distribution in terms + # of alpha > -1.0 + if alpha <= 0.0 or beta <= 0.0: + raise ValueError, 'gammavariate: alpha and beta must be > 0.0' + + random = self.random + if alpha > 1.0: + + # Uses R.C.H. Cheng, "The generation of Gamma + # variables with non-integral shape parameters", + # Applied Statistics, (1977), 26, No. 1, p71-74 + + ainv = _sqrt(2.0 * alpha - 1.0) + bbb = alpha - LOG4 + ccc = alpha + ainv + + while 1: + u1 = random() + if not 1e-7 < u1 < .9999999: + continue + u2 = 1.0 - random() + v = _log(u1/(1.0-u1))/ainv + x = alpha*_exp(v) + z = u1*u1*u2 + r = bbb+ccc*v-x + if r + SG_MAGICCONST - 4.5*z >= 0.0 or r >= _log(z): + return x * beta + + elif alpha == 1.0: + # expovariate(1) + u = random() + while u <= 1e-7: + u = random() + return -_log(u) * beta + + else: # alpha is between 0 and 1 (exclusive) + + # Uses ALGORITHM GS of Statistical Computing - Kennedy & Gentle + + while 1: + u = random() + b = (_e + alpha)/_e + p = b*u + if p <= 1.0: + x = p ** (1.0/alpha) + else: + x = -_log((b-p)/alpha) + u1 = random() + if p > 1.0: + if u1 <= x ** (alpha - 1.0): + break + elif u1 <= _exp(-x): + break + return x * beta + +## -------------------- Gauss (faster alternative) -------------------- + + def gauss(self, mu, sigma): + """Gaussian distribution. + + mu is the mean, and sigma is the standard deviation. This is + slightly faster than the normalvariate() function. + + Not thread-safe without a lock around calls. + + """ + + # When x and y are two variables from [0, 1), uniformly + # distributed, then + # + # cos(2*pi*x)*sqrt(-2*log(1-y)) + # sin(2*pi*x)*sqrt(-2*log(1-y)) + # + # are two *independent* variables with normal distribution + # (mu = 0, sigma = 1). + # (Lambert Meertens) + # (corrected version; bug discovered by Mike Miller, fixed by LM) + + # Multithreading note: When two threads call this function + # simultaneously, it is possible that they will receive the + # same return value. The window is very small though. To + # avoid this, you have to use a lock around all calls. (I + # didn't want to slow this down in the serial case by using a + # lock here.) + + random = self.random + z = self.gauss_next + self.gauss_next = None + if z is None: + x2pi = random() * TWOPI + g2rad = _sqrt(-2.0 * _log(1.0 - random())) + z = _cos(x2pi) * g2rad + self.gauss_next = _sin(x2pi) * g2rad + + return mu + z*sigma + +## -------------------- beta -------------------- +## See +## http://mail.python.org/pipermail/python-bugs-list/2001-January/003752.html +## for Ivan Frohne's insightful analysis of why the original implementation: +## +## def betavariate(self, alpha, beta): +## # Discrete Event Simulation in C, pp 87-88. +## +## y = self.expovariate(alpha) +## z = self.expovariate(1.0/beta) +## return z/(y+z) +## +## was dead wrong, and how it probably got that way. + + def betavariate(self, alpha, beta): + """Beta distribution. + + Conditions on the parameters are alpha > 0 and beta > 0. + Returned values range between 0 and 1. + + """ + + # This version due to Janne Sinkkonen, and matches all the std + # texts (e.g., Knuth Vol 2 Ed 3 pg 134 "the beta distribution"). + y = self.gammavariate(alpha, 1.) + if y == 0: + return 0.0 + else: + return y / (y + self.gammavariate(beta, 1.)) + +## -------------------- Pareto -------------------- + + def paretovariate(self, alpha): + """Pareto distribution. alpha is the shape parameter.""" + # Jain, pg. 495 + + u = 1.0 - self.random() + return 1.0 / pow(u, 1.0/alpha) + +## -------------------- Weibull -------------------- + + def weibullvariate(self, alpha, beta): + """Weibull distribution. + + alpha is the scale parameter and beta is the shape parameter. + + """ + # Jain, pg. 499; bug fix courtesy Bill Arms + + u = 1.0 - self.random() + return alpha * pow(-_log(u), 1.0/beta) + +## -------------------- Wichmann-Hill ------------------- + +class WichmannHill(Random): + + VERSION = 1 # used by getstate/setstate + + def seed(self, a=None): + """Initialize internal state from hashable object. + + None or no argument seeds from current time or from an operating + system specific randomness source if available. + + If a is not None or an int or long, hash(a) is used instead. + + If a is an int or long, a is used directly. Distinct values between + 0 and 27814431486575L inclusive are guaranteed to yield distinct + internal states (this guarantee is specific to the default + Wichmann-Hill generator). + """ + + if a is None: + try: + a = long(_hexlify(_urandom(16)), 16) + except NotImplementedError: + import time + a = long(time.time() * 256) # use fractional seconds + + if not isinstance(a, (int, long)): + a = hash(a) + + a, x = divmod(a, 30268) + a, y = divmod(a, 30306) + a, z = divmod(a, 30322) + self._seed = int(x)+1, int(y)+1, int(z)+1 + + self.gauss_next = None + + def random(self): + """Get the next random number in the range [0.0, 1.0).""" + + # Wichman-Hill random number generator. + # + # Wichmann, B. A. & Hill, I. D. (1982) + # Algorithm AS 183: + # An efficient and portable pseudo-random number generator + # Applied Statistics 31 (1982) 188-190 + # + # see also: + # Correction to Algorithm AS 183 + # Applied Statistics 33 (1984) 123 + # + # McLeod, A. I. (1985) + # A remark on Algorithm AS 183 + # Applied Statistics 34 (1985),198-200 + + # This part is thread-unsafe: + # BEGIN CRITICAL SECTION + x, y, z = self._seed + x = (171 * x) % 30269 + y = (172 * y) % 30307 + z = (170 * z) % 30323 + self._seed = x, y, z + # END CRITICAL SECTION + + # Note: on a platform using IEEE-754 double arithmetic, this can + # never return 0.0 (asserted by Tim; proof too long for a comment). + return (x/30269.0 + y/30307.0 + z/30323.0) % 1.0 + + def getstate(self): + """Return internal state; can be passed to setstate() later.""" + return self.VERSION, self._seed, self.gauss_next + + def setstate(self, state): + """Restore internal state from object returned by getstate().""" + version = state[0] + if version == 1: + version, self._seed, self.gauss_next = state + else: + raise ValueError("state with version %s passed to " + "Random.setstate() of version %s" % + (version, self.VERSION)) + + def jumpahead(self, n): + """Act as if n calls to random() were made, but quickly. + + n is an int, greater than or equal to 0. + + Example use: If you have 2 threads and know that each will + consume no more than a million random numbers, create two Random + objects r1 and r2, then do + r2.setstate(r1.getstate()) + r2.jumpahead(1000000) + Then r1 and r2 will use guaranteed-disjoint segments of the full + period. + """ + + if not n >= 0: + raise ValueError("n must be >= 0") + x, y, z = self._seed + x = int(x * pow(171, n, 30269)) % 30269 + y = int(y * pow(172, n, 30307)) % 30307 + z = int(z * pow(170, n, 30323)) % 30323 + self._seed = x, y, z + + def __whseed(self, x=0, y=0, z=0): + """Set the Wichmann-Hill seed from (x, y, z). + + These must be integers in the range [0, 256). + """ + + if not type(x) == type(y) == type(z) == int: + raise TypeError('seeds must be integers') + if not (0 <= x < 256 and 0 <= y < 256 and 0 <= z < 256): + raise ValueError('seeds must be in range(0, 256)') + if 0 == x == y == z: + # Initialize from current time + import time + t = long(time.time() * 256) + t = int((t&0xffffff) ^ (t>>24)) + t, x = divmod(t, 256) + t, y = divmod(t, 256) + t, z = divmod(t, 256) + # Zero is a poor seed, so substitute 1 + self._seed = (x or 1, y or 1, z or 1) + + self.gauss_next = None + + def whseed(self, a=None): + """Seed from hashable object's hash code. + + None or no argument seeds from current time. It is not guaranteed + that objects with distinct hash codes lead to distinct internal + states. + + This is obsolete, provided for compatibility with the seed routine + used prior to Python 2.1. Use the .seed() method instead. + """ + + if a is None: + self.__whseed() + return + a = hash(a) + a, x = divmod(a, 256) + a, y = divmod(a, 256) + a, z = divmod(a, 256) + x = (x + a) % 256 or 1 + y = (y + a) % 256 or 1 + z = (z + a) % 256 or 1 + self.__whseed(x, y, z) + +## --------------- Operating System Random Source ------------------ + +class SystemRandom(Random): + """Alternate random number generator using sources provided + by the operating system (such as /dev/urandom on Unix or + CryptGenRandom on Windows). + + Not available on all systems (see os.urandom() for details). + """ + + def random(self): + """Get the next random number in the range [0.0, 1.0).""" + return (long(_hexlify(_urandom(7)), 16) >> 3) * RECIP_BPF + + def getrandbits(self, k): + """getrandbits(k) -> x. Generates a long int with k random bits.""" + if k <= 0: + raise ValueError('number of bits must be greater than zero') + if k != int(k): + raise TypeError('number of bits should be an integer') + bytes = (k + 7) // 8 # bits / 8 and rounded up + x = long(_hexlify(_urandom(bytes)), 16) + return x >> (bytes * 8 - k) # trim excess bits + + def _stub(self, *args, **kwds): + "Stub method. Not used for a system random number generator." + return None + seed = jumpahead = _stub + + def _notimplemented(self, *args, **kwds): + "Method should not be called for a system random number generator." + raise NotImplementedError('System entropy source does not have state.') + getstate = setstate = _notimplemented + +## -------------------- test program -------------------- + +def _test_generator(n, func, args): + import time + print n, 'times', func.__name__ + total = 0.0 + sqsum = 0.0 + smallest = 1e10 + largest = -1e10 + t0 = time.time() + for i in range(n): + x = func(*args) + total += x + sqsum = sqsum + x*x + smallest = min(x, smallest) + largest = max(x, largest) + t1 = time.time() + print round(t1-t0, 3), 'sec,', + avg = total/n + stddev = _sqrt(sqsum/n - avg*avg) + print 'avg %g, stddev %g, min %g, max %g' % \ + (avg, stddev, smallest, largest) + + +def _test(N=2000): + _test_generator(N, random, ()) + _test_generator(N, normalvariate, (0.0, 1.0)) + _test_generator(N, lognormvariate, (0.0, 1.0)) + _test_generator(N, vonmisesvariate, (0.0, 1.0)) + _test_generator(N, gammavariate, (0.01, 1.0)) + _test_generator(N, gammavariate, (0.1, 1.0)) + _test_generator(N, gammavariate, (0.1, 2.0)) + _test_generator(N, gammavariate, (0.5, 1.0)) + _test_generator(N, gammavariate, (0.9, 1.0)) + _test_generator(N, gammavariate, (1.0, 1.0)) + _test_generator(N, gammavariate, (2.0, 1.0)) + _test_generator(N, gammavariate, (20.0, 1.0)) + _test_generator(N, gammavariate, (200.0, 1.0)) + _test_generator(N, gauss, (0.0, 1.0)) + _test_generator(N, betavariate, (3.0, 3.0)) + _test_generator(N, triangular, (0.0, 1.0, 1.0/3.0)) + +# Create one instance, seeded from current time, and export its methods +# as module-level functions. The functions share state across all uses +#(both in the user's code and in the Python libraries), but that's fine +# for most programs and is easier for the casual user than making them +# instantiate their own Random() instance. + +_inst = Random() +seed = _inst.seed +random = _inst.random +uniform = _inst.uniform +triangular = _inst.triangular +randint = _inst.randint +choice = _inst.choice +randrange = _inst.randrange +sample = _inst.sample +shuffle = _inst.shuffle +normalvariate = _inst.normalvariate +lognormvariate = _inst.lognormvariate +expovariate = _inst.expovariate +vonmisesvariate = _inst.vonmisesvariate +gammavariate = _inst.gammavariate +gauss = _inst.gauss +betavariate = _inst.betavariate +paretovariate = _inst.paretovariate +weibullvariate = _inst.weibullvariate +getstate = _inst.getstate +setstate = _inst.setstate +jumpahead = _inst.jumpahead +getrandbits = _inst.getrandbits + +if __name__ == '__main__': + _test() diff --git a/rfc822.py b/rfc822.py new file mode 100644 index 0000000..c1d0865 --- /dev/null +++ b/rfc822.py @@ -0,0 +1,1016 @@ +"""RFC 2822 message manipulation. + +Note: This is only a very rough sketch of a full RFC-822 parser; in particular +the tokenizing of addresses does not adhere to all the quoting rules. + +Note: RFC 2822 is a long awaited update to RFC 822. This module should +conform to RFC 2822, and is thus mis-named (it's not worth renaming it). Some +effort at RFC 2822 updates have been made, but a thorough audit has not been +performed. Consider any RFC 2822 non-conformance to be a bug. + + RFC 2822: http://www.faqs.org/rfcs/rfc2822.html + RFC 822 : http://www.faqs.org/rfcs/rfc822.html (obsolete) + +Directions for use: + +To create a Message object: first open a file, e.g.: + + fp = open(file, 'r') + +You can use any other legal way of getting an open file object, e.g. use +sys.stdin or call os.popen(). Then pass the open file object to the Message() +constructor: + + m = Message(fp) + +This class can work with any input object that supports a readline method. If +the input object has seek and tell capability, the rewindbody method will +work; also illegal lines will be pushed back onto the input stream. If the +input object lacks seek but has an `unread' method that can push back a line +of input, Message will use that to push back illegal lines. Thus this class +can be used to parse messages coming from a buffered stream. + +The optional `seekable' argument is provided as a workaround for certain stdio +libraries in which tell() discards buffered data before discovering that the +lseek() system call doesn't work. For maximum portability, you should set the +seekable argument to zero to prevent that initial \code{tell} when passing in +an unseekable object such as a file object created from a socket object. If +it is 1 on entry -- which it is by default -- the tell() method of the open +file object is called once; if this raises an exception, seekable is reset to +0. For other nonzero values of seekable, this test is not made. + +To get the text of a particular header there are several methods: + + str = m.getheader(name) + str = m.getrawheader(name) + +where name is the name of the header, e.g. 'Subject'. The difference is that +getheader() strips the leading and trailing whitespace, while getrawheader() +doesn't. Both functions retain embedded whitespace (including newlines) +exactly as they are specified in the header, and leave the case of the text +unchanged. + +For addresses and address lists there are functions + + realname, mailaddress = m.getaddr(name) + list = m.getaddrlist(name) + +where the latter returns a list of (realname, mailaddr) tuples. + +There is also a method + + time = m.getdate(name) + +which parses a Date-like field and returns a time-compatible tuple, +i.e. a tuple such as returned by time.localtime() or accepted by +time.mktime(). + +See the class definition for lower level access methods. + +There are also some utility functions here. +""" +# Cleanup and extensions by Eric S. Raymond + +import time + +from warnings import warnpy3k +warnpy3k("in 3.x, rfc822 has been removed in favor of the email package", + stacklevel=2) + +__all__ = ["Message","AddressList","parsedate","parsedate_tz","mktime_tz"] + +_blanklines = ('\r\n', '\n') # Optimization for islast() + + +class Message: + """Represents a single RFC 2822-compliant message.""" + + def __init__(self, fp, seekable = 1): + """Initialize the class instance and read the headers.""" + if seekable == 1: + # Exercise tell() to make sure it works + # (and then assume seek() works, too) + try: + fp.tell() + except (AttributeError, IOError): + seekable = 0 + self.fp = fp + self.seekable = seekable + self.startofheaders = None + self.startofbody = None + # + if self.seekable: + try: + self.startofheaders = self.fp.tell() + except IOError: + self.seekable = 0 + # + self.readheaders() + # + if self.seekable: + try: + self.startofbody = self.fp.tell() + except IOError: + self.seekable = 0 + + def rewindbody(self): + """Rewind the file to the start of the body (if seekable).""" + if not self.seekable: + raise IOError, "unseekable file" + self.fp.seek(self.startofbody) + + def readheaders(self): + """Read header lines. + + Read header lines up to the entirely blank line that terminates them. + The (normally blank) line that ends the headers is skipped, but not + included in the returned list. If a non-header line ends the headers, + (which is an error), an attempt is made to backspace over it; it is + never included in the returned list. + + The variable self.status is set to the empty string if all went well, + otherwise it is an error message. The variable self.headers is a + completely uninterpreted list of lines contained in the header (so + printing them will reproduce the header exactly as it appears in the + file). + """ + self.dict = {} + self.unixfrom = '' + self.headers = lst = [] + self.status = '' + headerseen = "" + firstline = 1 + startofline = unread = tell = None + if hasattr(self.fp, 'unread'): + unread = self.fp.unread + elif self.seekable: + tell = self.fp.tell + while 1: + if tell: + try: + startofline = tell() + except IOError: + startofline = tell = None + self.seekable = 0 + line = self.fp.readline() + if not line: + self.status = 'EOF in headers' + break + # Skip unix From name time lines + if firstline and line.startswith('From '): + self.unixfrom = self.unixfrom + line + continue + firstline = 0 + if headerseen and line[0] in ' \t': + # It's a continuation line. + lst.append(line) + x = (self.dict[headerseen] + "\n " + line.strip()) + self.dict[headerseen] = x.strip() + continue + elif self.iscomment(line): + # It's a comment. Ignore it. + continue + elif self.islast(line): + # Note! No pushback here! The delimiter line gets eaten. + break + headerseen = self.isheader(line) + if headerseen: + # It's a legal header line, save it. + lst.append(line) + self.dict[headerseen] = line[len(headerseen)+1:].strip() + continue + elif headerseen is not None: + # An empty header name. These aren't allowed in HTTP, but it's + # probably a benign mistake. Don't add the header, just keep + # going. + continue + else: + # It's not a header line; throw it back and stop here. + if not self.dict: + self.status = 'No headers' + else: + self.status = 'Non-header line where header expected' + # Try to undo the read. + if unread: + unread(line) + elif tell: + self.fp.seek(startofline) + else: + self.status = self.status + '; bad seek' + break + + def isheader(self, line): + """Determine whether a given line is a legal header. + + This method should return the header name, suitably canonicalized. + You may override this method in order to use Message parsing on tagged + data in RFC 2822-like formats with special header formats. + """ + i = line.find(':') + if i > -1: + return line[:i].lower() + return None + + def islast(self, line): + """Determine whether a line is a legal end of RFC 2822 headers. + + You may override this method if your application wants to bend the + rules, e.g. to strip trailing whitespace, or to recognize MH template + separators ('--------'). For convenience (e.g. for code reading from + sockets) a line consisting of \\r\\n also matches. + """ + return line in _blanklines + + def iscomment(self, line): + """Determine whether a line should be skipped entirely. + + You may override this method in order to use Message parsing on tagged + data in RFC 2822-like formats that support embedded comments or + free-text data. + """ + return False + + def getallmatchingheaders(self, name): + """Find all header lines matching a given header name. + + Look through the list of headers and find all lines matching a given + header name (and their continuation lines). A list of the lines is + returned, without interpretation. If the header does not occur, an + empty list is returned. If the header occurs multiple times, all + occurrences are returned. Case is not important in the header name. + """ + name = name.lower() + ':' + n = len(name) + lst = [] + hit = 0 + for line in self.headers: + if line[:n].lower() == name: + hit = 1 + elif not line[:1].isspace(): + hit = 0 + if hit: + lst.append(line) + return lst + + def getfirstmatchingheader(self, name): + """Get the first header line matching name. + + This is similar to getallmatchingheaders, but it returns only the + first matching header (and its continuation lines). + """ + name = name.lower() + ':' + n = len(name) + lst = [] + hit = 0 + for line in self.headers: + if hit: + if not line[:1].isspace(): + break + elif line[:n].lower() == name: + hit = 1 + if hit: + lst.append(line) + return lst + + def getrawheader(self, name): + """A higher-level interface to getfirstmatchingheader(). + + Return a string containing the literal text of the header but with the + keyword stripped. All leading, trailing and embedded whitespace is + kept in the string, however. Return None if the header does not + occur. + """ + + lst = self.getfirstmatchingheader(name) + if not lst: + return None + lst[0] = lst[0][len(name) + 1:] + return ''.join(lst) + + def getheader(self, name, default=None): + """Get the header value for a name. + + This is the normal interface: it returns a stripped version of the + header value for a given header name, or None if it doesn't exist. + This uses the dictionary version which finds the *last* such header. + """ + return self.dict.get(name.lower(), default) + get = getheader + + def getheaders(self, name): + """Get all values for a header. + + This returns a list of values for headers given more than once; each + value in the result list is stripped in the same way as the result of + getheader(). If the header is not given, return an empty list. + """ + result = [] + current = '' + have_header = 0 + for s in self.getallmatchingheaders(name): + if s[0].isspace(): + if current: + current = "%s\n %s" % (current, s.strip()) + else: + current = s.strip() + else: + if have_header: + result.append(current) + current = s[s.find(":") + 1:].strip() + have_header = 1 + if have_header: + result.append(current) + return result + + def getaddr(self, name): + """Get a single address from a header, as a tuple. + + An example return value: + ('Guido van Rossum', 'guido@cwi.nl') + """ + # New, by Ben Escoto + alist = self.getaddrlist(name) + if alist: + return alist[0] + else: + return (None, None) + + def getaddrlist(self, name): + """Get a list of addresses from a header. + + Retrieves a list of addresses from a header, where each address is a + tuple as returned by getaddr(). Scans all named headers, so it works + properly with multiple To: or Cc: headers for example. + """ + raw = [] + for h in self.getallmatchingheaders(name): + if h[0] in ' \t': + raw.append(h) + else: + if raw: + raw.append(', ') + i = h.find(':') + if i > 0: + addr = h[i+1:] + raw.append(addr) + alladdrs = ''.join(raw) + a = AddressList(alladdrs) + return a.addresslist + + def getdate(self, name): + """Retrieve a date field from a header. + + Retrieves a date field from the named header, returning a tuple + compatible with time.mktime(). + """ + try: + data = self[name] + except KeyError: + return None + return parsedate(data) + + def getdate_tz(self, name): + """Retrieve a date field from a header as a 10-tuple. + + The first 9 elements make up a tuple compatible with time.mktime(), + and the 10th is the offset of the poster's time zone from GMT/UTC. + """ + try: + data = self[name] + except KeyError: + return None + return parsedate_tz(data) + + + # Access as a dictionary (only finds *last* header of each type): + + def __len__(self): + """Get the number of headers in a message.""" + return len(self.dict) + + def __getitem__(self, name): + """Get a specific header, as from a dictionary.""" + return self.dict[name.lower()] + + def __setitem__(self, name, value): + """Set the value of a header. + + Note: This is not a perfect inversion of __getitem__, because any + changed headers get stuck at the end of the raw-headers list rather + than where the altered header was. + """ + del self[name] # Won't fail if it doesn't exist + self.dict[name.lower()] = value + text = name + ": " + value + for line in text.split("\n"): + self.headers.append(line + "\n") + + def __delitem__(self, name): + """Delete all occurrences of a specific header, if it is present.""" + name = name.lower() + if not name in self.dict: + return + del self.dict[name] + name = name + ':' + n = len(name) + lst = [] + hit = 0 + for i in range(len(self.headers)): + line = self.headers[i] + if line[:n].lower() == name: + hit = 1 + elif not line[:1].isspace(): + hit = 0 + if hit: + lst.append(i) + for i in reversed(lst): + del self.headers[i] + + def setdefault(self, name, default=""): + lowername = name.lower() + if lowername in self.dict: + return self.dict[lowername] + else: + text = name + ": " + default + for line in text.split("\n"): + self.headers.append(line + "\n") + self.dict[lowername] = default + return default + + def has_key(self, name): + """Determine whether a message contains the named header.""" + return name.lower() in self.dict + + def __contains__(self, name): + """Determine whether a message contains the named header.""" + return name.lower() in self.dict + + def __iter__(self): + return iter(self.dict) + + def keys(self): + """Get all of a message's header field names.""" + return self.dict.keys() + + def values(self): + """Get all of a message's header field values.""" + return self.dict.values() + + def items(self): + """Get all of a message's headers. + + Returns a list of name, value tuples. + """ + return self.dict.items() + + def __str__(self): + return ''.join(self.headers) + + +# Utility functions +# ----------------- + +# XXX Should fix unquote() and quote() to be really conformant. +# XXX The inverses of the parse functions may also be useful. + + +def unquote(s): + """Remove quotes from a string.""" + if len(s) > 1: + if s.startswith('"') and s.endswith('"'): + return s[1:-1].replace('\\\\', '\\').replace('\\"', '"') + if s.startswith('<') and s.endswith('>'): + return s[1:-1] + return s + + +def quote(s): + """Add quotes around a string.""" + return s.replace('\\', '\\\\').replace('"', '\\"') + + +def parseaddr(address): + """Parse an address into a (realname, mailaddr) tuple.""" + a = AddressList(address) + lst = a.addresslist + if not lst: + return (None, None) + return lst[0] + + +class AddrlistClass: + """Address parser class by Ben Escoto. + + To understand what this class does, it helps to have a copy of + RFC 2822 in front of you. + + http://www.faqs.org/rfcs/rfc2822.html + + Note: this class interface is deprecated and may be removed in the future. + Use rfc822.AddressList instead. + """ + + def __init__(self, field): + """Initialize a new instance. + + `field' is an unparsed address header field, containing one or more + addresses. + """ + self.specials = '()<>@,:;.\"[]' + self.pos = 0 + self.LWS = ' \t' + self.CR = '\r\n' + self.atomends = self.specials + self.LWS + self.CR + # Note that RFC 2822 now specifies `.' as obs-phrase, meaning that it + # is obsolete syntax. RFC 2822 requires that we recognize obsolete + # syntax, so allow dots in phrases. + self.phraseends = self.atomends.replace('.', '') + self.field = field + self.commentlist = [] + + def gotonext(self): + """Parse up to the start of the next address.""" + while self.pos < len(self.field): + if self.field[self.pos] in self.LWS + '\n\r': + self.pos = self.pos + 1 + elif self.field[self.pos] == '(': + self.commentlist.append(self.getcomment()) + else: break + + def getaddrlist(self): + """Parse all addresses. + + Returns a list containing all of the addresses. + """ + result = [] + ad = self.getaddress() + while ad: + result += ad + ad = self.getaddress() + return result + + def getaddress(self): + """Parse the next address.""" + self.commentlist = [] + self.gotonext() + + oldpos = self.pos + oldcl = self.commentlist + plist = self.getphraselist() + + self.gotonext() + returnlist = [] + + if self.pos >= len(self.field): + # Bad email address technically, no domain. + if plist: + returnlist = [(' '.join(self.commentlist), plist[0])] + + elif self.field[self.pos] in '.@': + # email address is just an addrspec + # this isn't very efficient since we start over + self.pos = oldpos + self.commentlist = oldcl + addrspec = self.getaddrspec() + returnlist = [(' '.join(self.commentlist), addrspec)] + + elif self.field[self.pos] == ':': + # address is a group + returnlist = [] + + fieldlen = len(self.field) + self.pos += 1 + while self.pos < len(self.field): + self.gotonext() + if self.pos < fieldlen and self.field[self.pos] == ';': + self.pos += 1 + break + returnlist = returnlist + self.getaddress() + + elif self.field[self.pos] == '<': + # Address is a phrase then a route addr + routeaddr = self.getrouteaddr() + + if self.commentlist: + returnlist = [(' '.join(plist) + ' (' + \ + ' '.join(self.commentlist) + ')', routeaddr)] + else: returnlist = [(' '.join(plist), routeaddr)] + + else: + if plist: + returnlist = [(' '.join(self.commentlist), plist[0])] + elif self.field[self.pos] in self.specials: + self.pos += 1 + + self.gotonext() + if self.pos < len(self.field) and self.field[self.pos] == ',': + self.pos += 1 + return returnlist + + def getrouteaddr(self): + """Parse a route address (Return-path value). + + This method just skips all the route stuff and returns the addrspec. + """ + if self.field[self.pos] != '<': + return + + expectroute = 0 + self.pos += 1 + self.gotonext() + adlist = "" + while self.pos < len(self.field): + if expectroute: + self.getdomain() + expectroute = 0 + elif self.field[self.pos] == '>': + self.pos += 1 + break + elif self.field[self.pos] == '@': + self.pos += 1 + expectroute = 1 + elif self.field[self.pos] == ':': + self.pos += 1 + else: + adlist = self.getaddrspec() + self.pos += 1 + break + self.gotonext() + + return adlist + + def getaddrspec(self): + """Parse an RFC 2822 addr-spec.""" + aslist = [] + + self.gotonext() + while self.pos < len(self.field): + if self.field[self.pos] == '.': + aslist.append('.') + self.pos += 1 + elif self.field[self.pos] == '"': + aslist.append('"%s"' % self.getquote()) + elif self.field[self.pos] in self.atomends: + break + else: aslist.append(self.getatom()) + self.gotonext() + + if self.pos >= len(self.field) or self.field[self.pos] != '@': + return ''.join(aslist) + + aslist.append('@') + self.pos += 1 + self.gotonext() + return ''.join(aslist) + self.getdomain() + + def getdomain(self): + """Get the complete domain name from an address.""" + sdlist = [] + while self.pos < len(self.field): + if self.field[self.pos] in self.LWS: + self.pos += 1 + elif self.field[self.pos] == '(': + self.commentlist.append(self.getcomment()) + elif self.field[self.pos] == '[': + sdlist.append(self.getdomainliteral()) + elif self.field[self.pos] == '.': + self.pos += 1 + sdlist.append('.') + elif self.field[self.pos] in self.atomends: + break + else: sdlist.append(self.getatom()) + return ''.join(sdlist) + + def getdelimited(self, beginchar, endchars, allowcomments = 1): + """Parse a header fragment delimited by special characters. + + `beginchar' is the start character for the fragment. If self is not + looking at an instance of `beginchar' then getdelimited returns the + empty string. + + `endchars' is a sequence of allowable end-delimiting characters. + Parsing stops when one of these is encountered. + + If `allowcomments' is non-zero, embedded RFC 2822 comments are allowed + within the parsed fragment. + """ + if self.field[self.pos] != beginchar: + return '' + + slist = [''] + quote = 0 + self.pos += 1 + while self.pos < len(self.field): + if quote == 1: + slist.append(self.field[self.pos]) + quote = 0 + elif self.field[self.pos] in endchars: + self.pos += 1 + break + elif allowcomments and self.field[self.pos] == '(': + slist.append(self.getcomment()) + continue # have already advanced pos from getcomment + elif self.field[self.pos] == '\\': + quote = 1 + else: + slist.append(self.field[self.pos]) + self.pos += 1 + + return ''.join(slist) + + def getquote(self): + """Get a quote-delimited fragment from self's field.""" + return self.getdelimited('"', '"\r', 0) + + def getcomment(self): + """Get a parenthesis-delimited fragment from self's field.""" + return self.getdelimited('(', ')\r', 1) + + def getdomainliteral(self): + """Parse an RFC 2822 domain-literal.""" + return '[%s]' % self.getdelimited('[', ']\r', 0) + + def getatom(self, atomends=None): + """Parse an RFC 2822 atom. + + Optional atomends specifies a different set of end token delimiters + (the default is to use self.atomends). This is used e.g. in + getphraselist() since phrase endings must not include the `.' (which + is legal in phrases).""" + atomlist = [''] + if atomends is None: + atomends = self.atomends + + while self.pos < len(self.field): + if self.field[self.pos] in atomends: + break + else: atomlist.append(self.field[self.pos]) + self.pos += 1 + + return ''.join(atomlist) + + def getphraselist(self): + """Parse a sequence of RFC 2822 phrases. + + A phrase is a sequence of words, which are in turn either RFC 2822 + atoms or quoted-strings. Phrases are canonicalized by squeezing all + runs of continuous whitespace into one space. + """ + plist = [] + + while self.pos < len(self.field): + if self.field[self.pos] in self.LWS: + self.pos += 1 + elif self.field[self.pos] == '"': + plist.append(self.getquote()) + elif self.field[self.pos] == '(': + self.commentlist.append(self.getcomment()) + elif self.field[self.pos] in self.phraseends: + break + else: + plist.append(self.getatom(self.phraseends)) + + return plist + +class AddressList(AddrlistClass): + """An AddressList encapsulates a list of parsed RFC 2822 addresses.""" + def __init__(self, field): + AddrlistClass.__init__(self, field) + if field: + self.addresslist = self.getaddrlist() + else: + self.addresslist = [] + + def __len__(self): + return len(self.addresslist) + + def __str__(self): + return ", ".join(map(dump_address_pair, self.addresslist)) + + def __add__(self, other): + # Set union + newaddr = AddressList(None) + newaddr.addresslist = self.addresslist[:] + for x in other.addresslist: + if not x in self.addresslist: + newaddr.addresslist.append(x) + return newaddr + + def __iadd__(self, other): + # Set union, in-place + for x in other.addresslist: + if not x in self.addresslist: + self.addresslist.append(x) + return self + + def __sub__(self, other): + # Set difference + newaddr = AddressList(None) + for x in self.addresslist: + if not x in other.addresslist: + newaddr.addresslist.append(x) + return newaddr + + def __isub__(self, other): + # Set difference, in-place + for x in other.addresslist: + if x in self.addresslist: + self.addresslist.remove(x) + return self + + def __getitem__(self, index): + # Make indexing, slices, and 'in' work + return self.addresslist[index] + +def dump_address_pair(pair): + """Dump a (name, address) pair in a canonicalized form.""" + if pair[0]: + return '"' + pair[0] + '" <' + pair[1] + '>' + else: + return pair[1] + +# Parse a date field + +_monthnames = ['jan', 'feb', 'mar', 'apr', 'may', 'jun', 'jul', + 'aug', 'sep', 'oct', 'nov', 'dec', + 'january', 'february', 'march', 'april', 'may', 'june', 'july', + 'august', 'september', 'october', 'november', 'december'] +_daynames = ['mon', 'tue', 'wed', 'thu', 'fri', 'sat', 'sun'] + +# The timezone table does not include the military time zones defined +# in RFC822, other than Z. According to RFC1123, the description in +# RFC822 gets the signs wrong, so we can't rely on any such time +# zones. RFC1123 recommends that numeric timezone indicators be used +# instead of timezone names. + +_timezones = {'UT':0, 'UTC':0, 'GMT':0, 'Z':0, + 'AST': -400, 'ADT': -300, # Atlantic (used in Canada) + 'EST': -500, 'EDT': -400, # Eastern + 'CST': -600, 'CDT': -500, # Central + 'MST': -700, 'MDT': -600, # Mountain + 'PST': -800, 'PDT': -700 # Pacific + } + + +def parsedate_tz(data): + """Convert a date string to a time tuple. + + Accounts for military timezones. + """ + if not data: + return None + data = data.split() + if data[0][-1] in (',', '.') or data[0].lower() in _daynames: + # There's a dayname here. Skip it + del data[0] + else: + # no space after the "weekday,"? + i = data[0].rfind(',') + if i >= 0: + data[0] = data[0][i+1:] + if len(data) == 3: # RFC 850 date, deprecated + stuff = data[0].split('-') + if len(stuff) == 3: + data = stuff + data[1:] + if len(data) == 4: + s = data[3] + i = s.find('+') + if i > 0: + data[3:] = [s[:i], s[i+1:]] + else: + data.append('') # Dummy tz + if len(data) < 5: + return None + data = data[:5] + [dd, mm, yy, tm, tz] = data + mm = mm.lower() + if not mm in _monthnames: + dd, mm = mm, dd.lower() + if not mm in _monthnames: + return None + mm = _monthnames.index(mm)+1 + if mm > 12: mm = mm - 12 + if dd[-1] == ',': + dd = dd[:-1] + i = yy.find(':') + if i > 0: + yy, tm = tm, yy + if yy[-1] == ',': + yy = yy[:-1] + if not yy[0].isdigit(): + yy, tz = tz, yy + if tm[-1] == ',': + tm = tm[:-1] + tm = tm.split(':') + if len(tm) == 2: + [thh, tmm] = tm + tss = '0' + elif len(tm) == 3: + [thh, tmm, tss] = tm + else: + return None + try: + yy = int(yy) + dd = int(dd) + thh = int(thh) + tmm = int(tmm) + tss = int(tss) + except ValueError: + return None + tzoffset = None + tz = tz.upper() + if tz in _timezones: + tzoffset = _timezones[tz] + else: + try: + tzoffset = int(tz) + except ValueError: + pass + # Convert a timezone offset into seconds ; -0500 -> -18000 + if tzoffset: + if tzoffset < 0: + tzsign = -1 + tzoffset = -tzoffset + else: + tzsign = 1 + tzoffset = tzsign * ( (tzoffset//100)*3600 + (tzoffset % 100)*60) + return (yy, mm, dd, thh, tmm, tss, 0, 1, 0, tzoffset) + + +def parsedate(data): + """Convert a time string to a time tuple.""" + t = parsedate_tz(data) + if t is None: + return t + return t[:9] + + +def mktime_tz(data): + """Turn a 10-tuple as returned by parsedate_tz() into a UTC timestamp.""" + if data[9] is None: + # No zone info, so localtime is better assumption than GMT + return time.mktime(data[:8] + (-1,)) + else: + t = time.mktime(data[:8] + (0,)) + return t - data[9] - time.timezone + +def formatdate(timeval=None): + """Returns time format preferred for Internet standards. + + Sun, 06 Nov 1994 08:49:37 GMT ; RFC 822, updated by RFC 1123 + + According to RFC 1123, day and month names must always be in + English. If not for that, this code could use strftime(). It + can't because strftime() honors the locale and could generated + non-English names. + """ + if timeval is None: + timeval = time.time() + timeval = time.gmtime(timeval) + return "%s, %02d %s %04d %02d:%02d:%02d GMT" % ( + ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")[timeval[6]], + timeval[2], + ("Jan", "Feb", "Mar", "Apr", "May", "Jun", + "Jul", "Aug", "Sep", "Oct", "Nov", "Dec")[timeval[1]-1], + timeval[0], timeval[3], timeval[4], timeval[5]) + + +# When used as script, run a small test program. +# The first command line argument must be a filename containing one +# message in RFC-822 format. + +if __name__ == '__main__': + import sys, os + file = os.path.join(os.environ['HOME'], 'Mail/inbox/1') + if sys.argv[1:]: file = sys.argv[1] + f = open(file, 'r') + m = Message(f) + print 'From:', m.getaddr('from') + print 'To:', m.getaddrlist('to') + print 'Subject:', m.getheader('subject') + print 'Date:', m.getheader('date') + date = m.getdate_tz('date') + tz = date[-1] + date = time.localtime(mktime_tz(date)) + if date: + print 'ParsedDate:', time.asctime(date), + hhmmss = tz + hhmm, ss = divmod(hhmmss, 60) + hh, mm = divmod(hhmm, 60) + print "%+03d%02d" % (hh, mm), + if ss: print ".%02d" % ss, + print + else: + print 'ParsedDate:', None + m.rewindbody() + n = 0 + while f.readline(): + n += 1 + print 'Lines:', n + print '-'*70 + print 'len =', len(m) + if 'Date' in m: print 'Date =', m['Date'] + if 'X-Nonsense' in m: pass + print 'keys =', m.keys() + print 'values =', m.values() + print 'items =', m.items() diff --git a/socket.py b/socket.py new file mode 100644 index 0000000..614af29 --- /dev/null +++ b/socket.py @@ -0,0 +1,577 @@ +# Wrapper module for _socket, providing some additional facilities +# implemented in Python. + +"""\ +This module provides socket operations and some related functions. +On Unix, it supports IP (Internet Protocol) and Unix domain sockets. +On other systems, it only supports IP. Functions specific for a +socket are available as methods of the socket object. + +Functions: + +socket() -- create a new socket object +socketpair() -- create a pair of new socket objects [*] +fromfd() -- create a socket object from an open file descriptor [*] +gethostname() -- return the current hostname +gethostbyname() -- map a hostname to its IP number +gethostbyaddr() -- map an IP number or hostname to DNS info +getservbyname() -- map a service name and a protocol name to a port number +getprotobyname() -- map a protocol name (e.g. 'tcp') to a number +ntohs(), ntohl() -- convert 16, 32 bit int from network to host byte order +htons(), htonl() -- convert 16, 32 bit int from host to network byte order +inet_aton() -- convert IP addr string (123.45.67.89) to 32-bit packed format +inet_ntoa() -- convert 32-bit packed format IP to string (123.45.67.89) +ssl() -- secure socket layer support (only available if configured) +socket.getdefaulttimeout() -- get the default timeout value +socket.setdefaulttimeout() -- set the default timeout value +create_connection() -- connects to an address, with an optional timeout and + optional source address. + + [*] not available on all platforms! + +Special objects: + +SocketType -- type object for socket objects +error -- exception raised for I/O errors +has_ipv6 -- boolean value indicating if IPv6 is supported + +Integer constants: + +AF_INET, AF_UNIX -- socket domains (first argument to socket() call) +SOCK_STREAM, SOCK_DGRAM, SOCK_RAW -- socket types (second argument) + +Many other constants may be defined; these may be used in calls to +the setsockopt() and getsockopt() methods. +""" + +import _socket +from _socket import * +from functools import partial +from types import MethodType + +try: + import _ssl +except ImportError: + # no SSL support + pass +else: + def ssl(sock, keyfile=None, certfile=None): + # we do an internal import here because the ssl + # module imports the socket module + import ssl as _realssl + warnings.warn("socket.ssl() is deprecated. Use ssl.wrap_socket() instead.", + DeprecationWarning, stacklevel=2) + return _realssl.sslwrap_simple(sock, keyfile, certfile) + + # we need to import the same constants we used to... + from _ssl import SSLError as sslerror + from _ssl import \ + RAND_add, \ + RAND_status, \ + SSL_ERROR_ZERO_RETURN, \ + SSL_ERROR_WANT_READ, \ + SSL_ERROR_WANT_WRITE, \ + SSL_ERROR_WANT_X509_LOOKUP, \ + SSL_ERROR_SYSCALL, \ + SSL_ERROR_SSL, \ + SSL_ERROR_WANT_CONNECT, \ + SSL_ERROR_EOF, \ + SSL_ERROR_INVALID_ERROR_CODE + try: + from _ssl import RAND_egd + except ImportError: + # LibreSSL does not provide RAND_egd + pass + +import os, sys, warnings + +try: + from cStringIO import StringIO +except ImportError: + from StringIO import StringIO + +try: + import errno +except ImportError: + errno = None +EBADF = getattr(errno, 'EBADF', 9) +EINTR = getattr(errno, 'EINTR', 4) + +__all__ = ["getfqdn", "create_connection"] +__all__.extend(os._get_exports_list(_socket)) + + +_realsocket = socket + +# WSA error codes +if sys.platform.lower().startswith("win"): + errorTab = {} + errorTab[10004] = "The operation was interrupted." + errorTab[10009] = "A bad file handle was passed." + errorTab[10013] = "Permission denied." + errorTab[10014] = "A fault occurred on the network??" # WSAEFAULT + errorTab[10022] = "An invalid operation was attempted." + errorTab[10035] = "The socket operation would block" + errorTab[10036] = "A blocking operation is already in progress." + errorTab[10048] = "The network address is in use." + errorTab[10054] = "The connection has been reset." + errorTab[10058] = "The network has been shut down." + errorTab[10060] = "The operation timed out." + errorTab[10061] = "Connection refused." + errorTab[10063] = "The name is too long." + errorTab[10064] = "The host is down." + errorTab[10065] = "The host is unreachable." + __all__.append("errorTab") + + + +def getfqdn(name=''): + """Get fully qualified domain name from name. + + An empty argument is interpreted as meaning the local host. + + First the hostname returned by gethostbyaddr() is checked, then + possibly existing aliases. In case no FQDN is available, hostname + from gethostname() is returned. + """ + name = name.strip() + if not name or name == '0.0.0.0': + name = gethostname() + try: + hostname, aliases, ipaddrs = gethostbyaddr(name) + except error: + pass + else: + aliases.insert(0, hostname) + for name in aliases: + if '.' in name: + break + else: + name = hostname + return name + + +_socketmethods = ( + 'bind', 'connect', 'connect_ex', 'fileno', 'listen', + 'getpeername', 'getsockname', 'getsockopt', 'setsockopt', + 'sendall', 'setblocking', + 'settimeout', 'gettimeout', 'shutdown') + +if os.name == "nt": + _socketmethods = _socketmethods + ('ioctl',) + +if sys.platform == "riscos": + _socketmethods = _socketmethods + ('sleeptaskw',) + +# All the method names that must be delegated to either the real socket +# object or the _closedsocket object. +_delegate_methods = ("recv", "recvfrom", "recv_into", "recvfrom_into", + "send", "sendto") + +class _closedsocket(object): + __slots__ = [] + def _dummy(*args): + raise error(EBADF, 'Bad file descriptor') + # All _delegate_methods must also be initialized here. + send = recv = recv_into = sendto = recvfrom = recvfrom_into = _dummy + __getattr__ = _dummy + +# Wrapper around platform socket objects. This implements +# a platform-independent dup() functionality. The +# implementation currently relies on reference counting +# to close the underlying socket object. +class _socketobject(object): + + __doc__ = _realsocket.__doc__ + + __slots__ = ["_sock", "__weakref__"] + list(_delegate_methods) + + def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, _sock=None): + if _sock is None: + _sock = _realsocket(family, type, proto) + self._sock = _sock + for method in _delegate_methods: + setattr(self, method, getattr(_sock, method)) + + def close(self, _closedsocket=_closedsocket, + _delegate_methods=_delegate_methods, setattr=setattr): + # This function should not reference any globals. See issue #808164. + self._sock = _closedsocket() + dummy = self._sock._dummy + for method in _delegate_methods: + setattr(self, method, dummy) + close.__doc__ = _realsocket.close.__doc__ + + def accept(self): + sock, addr = self._sock.accept() + return _socketobject(_sock=sock), addr + accept.__doc__ = _realsocket.accept.__doc__ + + def dup(self): + """dup() -> socket object + + Return a new socket object connected to the same system resource.""" + return _socketobject(_sock=self._sock) + + def makefile(self, mode='r', bufsize=-1): + """makefile([mode[, bufsize]]) -> file object + + Return a regular file object corresponding to the socket. The mode + and bufsize arguments are as for the built-in open() function.""" + return _fileobject(self._sock, mode, bufsize) + + family = property(lambda self: self._sock.family, doc="the socket family") + type = property(lambda self: self._sock.type, doc="the socket type") + proto = property(lambda self: self._sock.proto, doc="the socket protocol") + +def meth(name,self,*args): + return getattr(self._sock,name)(*args) + +for _m in _socketmethods: + p = partial(meth,_m) + p.__name__ = _m + p.__doc__ = getattr(_realsocket,_m).__doc__ + m = MethodType(p,None,_socketobject) + setattr(_socketobject,_m,m) + +socket = SocketType = _socketobject + +class _fileobject(object): + """Faux file object attached to a socket object.""" + + default_bufsize = 8192 + name = "" + + __slots__ = ["mode", "bufsize", "softspace", + # "closed" is a property, see below + "_sock", "_rbufsize", "_wbufsize", "_rbuf", "_wbuf", "_wbuf_len", + "_close"] + + def __init__(self, sock, mode='rb', bufsize=-1, close=False): + self._sock = sock + self.mode = mode # Not actually used in this version + if bufsize < 0: + bufsize = self.default_bufsize + self.bufsize = bufsize + self.softspace = False + # _rbufsize is the suggested recv buffer size. It is *strictly* + # obeyed within readline() for recv calls. If it is larger than + # default_bufsize it will be used for recv calls within read(). + if bufsize == 0: + self._rbufsize = 1 + elif bufsize == 1: + self._rbufsize = self.default_bufsize + else: + self._rbufsize = bufsize + self._wbufsize = bufsize + # We use StringIO for the read buffer to avoid holding a list + # of variously sized string objects which have been known to + # fragment the heap due to how they are malloc()ed and often + # realloc()ed down much smaller than their original allocation. + self._rbuf = StringIO() + self._wbuf = [] # A list of strings + self._wbuf_len = 0 + self._close = close + + def _getclosed(self): + return self._sock is None + closed = property(_getclosed, doc="True if the file is closed") + + def close(self): + try: + if self._sock: + self.flush() + finally: + if self._close: + self._sock.close() + self._sock = None + + def __del__(self): + try: + self.close() + except: + # close() may fail if __init__ didn't complete + pass + + def flush(self): + if self._wbuf: + data = "".join(self._wbuf) + self._wbuf = [] + self._wbuf_len = 0 + buffer_size = max(self._rbufsize, self.default_bufsize) + data_size = len(data) + write_offset = 0 + view = memoryview(data) + try: + while write_offset < data_size: + self._sock.sendall(view[write_offset:write_offset+buffer_size]) + write_offset += buffer_size + finally: + if write_offset < data_size: + remainder = data[write_offset:] + del view, data # explicit free + self._wbuf.append(remainder) + self._wbuf_len = len(remainder) + + def fileno(self): + return self._sock.fileno() + + def write(self, data): + data = str(data) # XXX Should really reject non-string non-buffers + if not data: + return + self._wbuf.append(data) + self._wbuf_len += len(data) + if (self._wbufsize == 0 or + (self._wbufsize == 1 and '\n' in data) or + (self._wbufsize > 1 and self._wbuf_len >= self._wbufsize)): + self.flush() + + def writelines(self, list): + # XXX We could do better here for very long lists + # XXX Should really reject non-string non-buffers + lines = filter(None, map(str, list)) + self._wbuf_len += sum(map(len, lines)) + self._wbuf.extend(lines) + if (self._wbufsize <= 1 or + self._wbuf_len >= self._wbufsize): + self.flush() + + def read(self, size=-1): + # Use max, disallow tiny reads in a loop as they are very inefficient. + # We never leave read() with any leftover data from a new recv() call + # in our internal buffer. + rbufsize = max(self._rbufsize, self.default_bufsize) + # Our use of StringIO rather than lists of string objects returned by + # recv() minimizes memory usage and fragmentation that occurs when + # rbufsize is large compared to the typical return value of recv(). + buf = self._rbuf + buf.seek(0, 2) # seek end + if size < 0: + # Read until EOF + self._rbuf = StringIO() # reset _rbuf. we consume it via buf. + while True: + try: + data = self._sock.recv(rbufsize) + except error, e: + if e.args[0] == EINTR: + continue + raise + if not data: + break + buf.write(data) + return buf.getvalue() + else: + # Read until size bytes or EOF seen, whichever comes first + buf_len = buf.tell() + if buf_len >= size: + # Already have size bytes in our buffer? Extract and return. + buf.seek(0) + rv = buf.read(size) + self._rbuf = StringIO() + self._rbuf.write(buf.read()) + return rv + + self._rbuf = StringIO() # reset _rbuf. we consume it via buf. + while True: + left = size - buf_len + # recv() will malloc the amount of memory given as its + # parameter even though it often returns much less data + # than that. The returned data string is short lived + # as we copy it into a StringIO and free it. This avoids + # fragmentation issues on many platforms. + try: + data = self._sock.recv(left) + except error, e: + if e.args[0] == EINTR: + continue + raise + if not data: + break + n = len(data) + if n == size and not buf_len: + # Shortcut. Avoid buffer data copies when: + # - We have no data in our buffer. + # AND + # - Our call to recv returned exactly the + # number of bytes we were asked to read. + return data + if n == left: + buf.write(data) + del data # explicit free + break + assert n <= left, "recv(%d) returned %d bytes" % (left, n) + buf.write(data) + buf_len += n + del data # explicit free + #assert buf_len == buf.tell() + return buf.getvalue() + + def readline(self, size=-1): + buf = self._rbuf + buf.seek(0, 2) # seek end + if buf.tell() > 0: + # check if we already have it in our buffer + buf.seek(0) + bline = buf.readline(size) + if bline.endswith('\n') or len(bline) == size: + self._rbuf = StringIO() + self._rbuf.write(buf.read()) + return bline + del bline + if size < 0: + # Read until \n or EOF, whichever comes first + if self._rbufsize <= 1: + # Speed up unbuffered case + buf.seek(0) + buffers = [buf.read()] + self._rbuf = StringIO() # reset _rbuf. we consume it via buf. + data = None + recv = self._sock.recv + while True: + try: + while data != "\n": + data = recv(1) + if not data: + break + buffers.append(data) + except error, e: + # The try..except to catch EINTR was moved outside the + # recv loop to avoid the per byte overhead. + if e.args[0] == EINTR: + continue + raise + break + return "".join(buffers) + + buf.seek(0, 2) # seek end + self._rbuf = StringIO() # reset _rbuf. we consume it via buf. + while True: + try: + data = self._sock.recv(self._rbufsize) + except error, e: + if e.args[0] == EINTR: + continue + raise + if not data: + break + nl = data.find('\n') + if nl >= 0: + nl += 1 + buf.write(data[:nl]) + self._rbuf.write(data[nl:]) + del data + break + buf.write(data) + return buf.getvalue() + else: + # Read until size bytes or \n or EOF seen, whichever comes first + buf.seek(0, 2) # seek end + buf_len = buf.tell() + if buf_len >= size: + buf.seek(0) + rv = buf.read(size) + self._rbuf = StringIO() + self._rbuf.write(buf.read()) + return rv + self._rbuf = StringIO() # reset _rbuf. we consume it via buf. + while True: + try: + data = self._sock.recv(self._rbufsize) + except error, e: + if e.args[0] == EINTR: + continue + raise + if not data: + break + left = size - buf_len + # did we just receive a newline? + nl = data.find('\n', 0, left) + if nl >= 0: + nl += 1 + # save the excess data to _rbuf + self._rbuf.write(data[nl:]) + if buf_len: + buf.write(data[:nl]) + break + else: + # Shortcut. Avoid data copy through buf when returning + # a substring of our first recv(). + return data[:nl] + n = len(data) + if n == size and not buf_len: + # Shortcut. Avoid data copy through buf when + # returning exactly all of our first recv(). + return data + if n >= left: + buf.write(data[:left]) + self._rbuf.write(data[left:]) + break + buf.write(data) + buf_len += n + #assert buf_len == buf.tell() + return buf.getvalue() + + def readlines(self, sizehint=0): + total = 0 + list = [] + while True: + line = self.readline() + if not line: + break + list.append(line) + total += len(line) + if sizehint and total >= sizehint: + break + return list + + # Iterator protocols + + def __iter__(self): + return self + + def next(self): + line = self.readline() + if not line: + raise StopIteration + return line + +_GLOBAL_DEFAULT_TIMEOUT = object() + +def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT, + source_address=None): + """Connect to *address* and return the socket object. + + Convenience function. Connect to *address* (a 2-tuple ``(host, + port)``) and return the socket object. Passing the optional + *timeout* parameter will set the timeout on the socket instance + before attempting to connect. If no *timeout* is supplied, the + global default timeout setting returned by :func:`getdefaulttimeout` + is used. If *source_address* is set it must be a tuple of (host, port) + for the socket to bind as a source address before making the connection. + An host of '' or port 0 tells the OS to use the default. + """ + + host, port = address + err = None + for res in getaddrinfo(host, port, 0, SOCK_STREAM): + af, socktype, proto, canonname, sa = res + sock = None + try: + sock = socket(af, socktype, proto) + if timeout is not _GLOBAL_DEFAULT_TIMEOUT: + sock.settimeout(timeout) + if source_address: + sock.bind(source_address) + sock.connect(sa) + return sock + + except error as _: + err = _ + if sock is not None: + sock.close() + + if err is not None: + raise err + else: + raise error("getaddrinfo returns an empty list") diff --git a/ssl.py b/ssl.py new file mode 100644 index 0000000..f3e5123 --- /dev/null +++ b/ssl.py @@ -0,0 +1,464 @@ +# Wrapper module for _ssl, providing some additional facilities +# implemented in Python. Written by Bill Janssen. + +"""\ +This module provides some more Pythonic support for SSL. + +Object types: + + SSLSocket -- subtype of socket.socket which does SSL over the socket + +Exceptions: + + SSLError -- exception raised for I/O errors + +Functions: + + cert_time_to_seconds -- convert time string used for certificate + notBefore and notAfter functions to integer + seconds past the Epoch (the time values + returned from time.time()) + + fetch_server_certificate (HOST, PORT) -- fetch the certificate provided + by the server running on HOST at port PORT. No + validation of the certificate is performed. + +Integer constants: + +SSL_ERROR_ZERO_RETURN +SSL_ERROR_WANT_READ +SSL_ERROR_WANT_WRITE +SSL_ERROR_WANT_X509_LOOKUP +SSL_ERROR_SYSCALL +SSL_ERROR_SSL +SSL_ERROR_WANT_CONNECT + +SSL_ERROR_EOF +SSL_ERROR_INVALID_ERROR_CODE + +The following group define certificate requirements that one side is +allowing/requiring from the other side: + +CERT_NONE - no certificates from the other side are required (or will + be looked at if provided) +CERT_OPTIONAL - certificates are not required, but if provided will be + validated, and if validation fails, the connection will + also fail +CERT_REQUIRED - certificates are required, and will be validated, and + if validation fails, the connection will also fail + +The following constants identify various SSL protocol variants: + +PROTOCOL_SSLv2 +PROTOCOL_SSLv3 +PROTOCOL_SSLv23 +PROTOCOL_TLSv1 +""" + +import textwrap + +import _ssl # if we can't import it, let the error propagate + +from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION +from _ssl import SSLError +from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED +from _ssl import RAND_status, RAND_egd, RAND_add +from _ssl import \ + SSL_ERROR_ZERO_RETURN, \ + SSL_ERROR_WANT_READ, \ + SSL_ERROR_WANT_WRITE, \ + SSL_ERROR_WANT_X509_LOOKUP, \ + SSL_ERROR_SYSCALL, \ + SSL_ERROR_SSL, \ + SSL_ERROR_WANT_CONNECT, \ + SSL_ERROR_EOF, \ + SSL_ERROR_INVALID_ERROR_CODE +from _ssl import PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1 +_PROTOCOL_NAMES = { + PROTOCOL_TLSv1: "TLSv1", + PROTOCOL_SSLv23: "SSLv23", + PROTOCOL_SSLv3: "SSLv3", +} +try: + from _ssl import PROTOCOL_SSLv2 +except ImportError: + pass +else: + _PROTOCOL_NAMES[PROTOCOL_SSLv2] = "SSLv2" + +from socket import socket, _fileobject, _delegate_methods, error as socket_error +from socket import getnameinfo as _getnameinfo +import base64 # for DER-to-PEM translation +import errno + +class SSLSocket(socket): + + """This class implements a subtype of socket.socket that wraps + the underlying OS socket in an SSL context when necessary, and + provides read and write methods over that channel.""" + + def __init__(self, sock, keyfile=None, certfile=None, + server_side=False, cert_reqs=CERT_NONE, + ssl_version=PROTOCOL_SSLv23, ca_certs=None, + do_handshake_on_connect=True, + suppress_ragged_eofs=True, ciphers=None): + socket.__init__(self, _sock=sock._sock) + # The initializer for socket overrides the methods send(), recv(), etc. + # in the instancce, which we don't need -- but we want to provide the + # methods defined in SSLSocket. + for attr in _delegate_methods: + try: + delattr(self, attr) + except AttributeError: + pass + + if certfile and not keyfile: + keyfile = certfile + # see if it's connected + try: + socket.getpeername(self) + except socket_error, e: + if e.errno != errno.ENOTCONN: + raise + # no, no connection yet + self._connected = False + self._sslobj = None + else: + # yes, create the SSL object + self._connected = True + self._sslobj = _ssl.sslwrap(self._sock, server_side, + keyfile, certfile, + cert_reqs, ssl_version, ca_certs, + ciphers) + if do_handshake_on_connect: + self.do_handshake() + self.keyfile = keyfile + self.certfile = certfile + self.cert_reqs = cert_reqs + self.ssl_version = ssl_version + self.ca_certs = ca_certs + self.ciphers = ciphers + self.do_handshake_on_connect = do_handshake_on_connect + self.suppress_ragged_eofs = suppress_ragged_eofs + self._makefile_refs = 0 + + def read(self, len=1024): + + """Read up to LEN bytes and return them. + Return zero-length string on EOF.""" + + try: + return self._sslobj.read(len) + except SSLError, x: + if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs: + return '' + else: + raise + + def write(self, data): + + """Write DATA to the underlying SSL channel. Returns + number of bytes of DATA actually transmitted.""" + + return self._sslobj.write(data) + + def getpeercert(self, binary_form=False): + + """Returns a formatted version of the data in the + certificate provided by the other end of the SSL channel. + Return None if no certificate was provided, {} if a + certificate was provided, but not validated.""" + + return self._sslobj.peer_certificate(binary_form) + + def cipher(self): + + if not self._sslobj: + return None + else: + return self._sslobj.cipher() + + def send(self, data, flags=0): + if self._sslobj: + if flags != 0: + raise ValueError( + "non-zero flags not allowed in calls to send() on %s" % + self.__class__) + while True: + try: + v = self._sslobj.write(data) + except SSLError, x: + if x.args[0] == SSL_ERROR_WANT_READ: + return 0 + elif x.args[0] == SSL_ERROR_WANT_WRITE: + return 0 + else: + raise + else: + return v + else: + return self._sock.send(data, flags) + + def sendto(self, data, flags_or_addr, addr=None): + if self._sslobj: + raise ValueError("sendto not allowed on instances of %s" % + self.__class__) + elif addr is None: + return self._sock.sendto(data, flags_or_addr) + else: + return self._sock.sendto(data, flags_or_addr, addr) + + def sendall(self, data, flags=0): + if self._sslobj: + if flags != 0: + raise ValueError( + "non-zero flags not allowed in calls to sendall() on %s" % + self.__class__) + amount = len(data) + count = 0 + while (count < amount): + v = self.send(data[count:]) + count += v + return amount + else: + return socket.sendall(self, data, flags) + + def recv(self, buflen=1024, flags=0): + if self._sslobj: + if flags != 0: + raise ValueError( + "non-zero flags not allowed in calls to recv() on %s" % + self.__class__) + return self.read(buflen) + else: + return self._sock.recv(buflen, flags) + + def recv_into(self, buffer, nbytes=None, flags=0): + if buffer and (nbytes is None): + nbytes = len(buffer) + elif nbytes is None: + nbytes = 1024 + if self._sslobj: + if flags != 0: + raise ValueError( + "non-zero flags not allowed in calls to recv_into() on %s" % + self.__class__) + tmp_buffer = self.read(nbytes) + v = len(tmp_buffer) + buffer[:v] = tmp_buffer + return v + else: + return self._sock.recv_into(buffer, nbytes, flags) + + def recvfrom(self, buflen=1024, flags=0): + if self._sslobj: + raise ValueError("recvfrom not allowed on instances of %s" % + self.__class__) + else: + return self._sock.recvfrom(buflen, flags) + + def recvfrom_into(self, buffer, nbytes=None, flags=0): + if self._sslobj: + raise ValueError("recvfrom_into not allowed on instances of %s" % + self.__class__) + else: + return self._sock.recvfrom_into(buffer, nbytes, flags) + + def pending(self): + if self._sslobj: + return self._sslobj.pending() + else: + return 0 + + def unwrap(self): + if self._sslobj: + s = self._sslobj.shutdown() + self._sslobj = None + return s + else: + raise ValueError("No SSL wrapper around " + str(self)) + + def shutdown(self, how): + self._sslobj = None + socket.shutdown(self, how) + + def close(self): + if self._makefile_refs < 1: + self._sslobj = None + socket.close(self) + else: + self._makefile_refs -= 1 + + def do_handshake(self): + + """Perform a TLS/SSL handshake.""" + + self._sslobj.do_handshake() + + def _real_connect(self, addr, return_errno): + # Here we assume that the socket is client-side, and not + # connected at the time of the call. We connect it, then wrap it. + if self._connected: + raise ValueError("attempt to connect already-connected SSLSocket!") + self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile, + self.cert_reqs, self.ssl_version, + self.ca_certs, self.ciphers) + try: + socket.connect(self, addr) + if self.do_handshake_on_connect: + self.do_handshake() + except socket_error as e: + if return_errno: + return e.errno + else: + self._sslobj = None + raise e + self._connected = True + return 0 + + def connect(self, addr): + """Connects to remote ADDR, and then wraps the connection in + an SSL channel.""" + self._real_connect(addr, False) + + def connect_ex(self, addr): + """Connects to remote ADDR, and then wraps the connection in + an SSL channel.""" + return self._real_connect(addr, True) + + def accept(self): + + """Accepts a new connection from a remote client, and returns + a tuple containing that new connection wrapped with a server-side + SSL channel, and the address of the remote client.""" + + newsock, addr = socket.accept(self) + return (SSLSocket(newsock, + keyfile=self.keyfile, + certfile=self.certfile, + server_side=True, + cert_reqs=self.cert_reqs, + ssl_version=self.ssl_version, + ca_certs=self.ca_certs, + ciphers=self.ciphers, + do_handshake_on_connect=self.do_handshake_on_connect, + suppress_ragged_eofs=self.suppress_ragged_eofs), + addr) + + def makefile(self, mode='r', bufsize=-1): + + """Make and return a file-like object that + works with the SSL connection. Just use the code + from the socket module.""" + + self._makefile_refs += 1 + # close=True so as to decrement the reference count when done with + # the file-like object. + return _fileobject(self, mode, bufsize, close=True) + + + +def wrap_socket(sock, keyfile=None, certfile=None, + server_side=False, cert_reqs=CERT_NONE, + ssl_version=PROTOCOL_SSLv23, ca_certs=None, + do_handshake_on_connect=True, + suppress_ragged_eofs=True, ciphers=None): + + return SSLSocket(sock, keyfile=keyfile, certfile=certfile, + server_side=server_side, cert_reqs=cert_reqs, + ssl_version=ssl_version, ca_certs=ca_certs, + do_handshake_on_connect=do_handshake_on_connect, + suppress_ragged_eofs=suppress_ragged_eofs, + ciphers=ciphers) + + +# some utility functions + +def cert_time_to_seconds(cert_time): + + """Takes a date-time string in standard ASN1_print form + ("MON DAY 24HOUR:MINUTE:SEC YEAR TIMEZONE") and return + a Python time value in seconds past the epoch.""" + + import time + return time.mktime(time.strptime(cert_time, "%b %d %H:%M:%S %Y GMT")) + +PEM_HEADER = "-----BEGIN CERTIFICATE-----" +PEM_FOOTER = "-----END CERTIFICATE-----" + +def DER_cert_to_PEM_cert(der_cert_bytes): + + """Takes a certificate in binary DER format and returns the + PEM version of it as a string.""" + + if hasattr(base64, 'standard_b64encode'): + # preferred because older API gets line-length wrong + f = base64.standard_b64encode(der_cert_bytes) + return (PEM_HEADER + '\n' + + textwrap.fill(f, 64) + '\n' + + PEM_FOOTER + '\n') + else: + return (PEM_HEADER + '\n' + + base64.encodestring(der_cert_bytes) + + PEM_FOOTER + '\n') + +def PEM_cert_to_DER_cert(pem_cert_string): + + """Takes a certificate in ASCII PEM format and returns the + DER-encoded version of it as a byte sequence""" + + if not pem_cert_string.startswith(PEM_HEADER): + raise ValueError("Invalid PEM encoding; must start with %s" + % PEM_HEADER) + if not pem_cert_string.strip().endswith(PEM_FOOTER): + raise ValueError("Invalid PEM encoding; must end with %s" + % PEM_FOOTER) + d = pem_cert_string.strip()[len(PEM_HEADER):-len(PEM_FOOTER)] + return base64.decodestring(d) + +def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None): + + """Retrieve the certificate from the server at the specified address, + and return it as a PEM-encoded string. + If 'ca_certs' is specified, validate the server cert against it. + If 'ssl_version' is specified, use it in the connection attempt.""" + + host, port = addr + if (ca_certs is not None): + cert_reqs = CERT_REQUIRED + else: + cert_reqs = CERT_NONE + s = wrap_socket(socket(), ssl_version=ssl_version, + cert_reqs=cert_reqs, ca_certs=ca_certs) + s.connect(addr) + dercert = s.getpeercert(True) + s.close() + return DER_cert_to_PEM_cert(dercert) + +def get_protocol_name(protocol_code): + return _PROTOCOL_NAMES.get(protocol_code, '') + + +# a replacement for the old socket.ssl function + +def sslwrap_simple(sock, keyfile=None, certfile=None): + + """A replacement for the old socket.ssl function. Designed + for compability with Python 2.5 and earlier. Will disappear in + Python 3.0.""" + + if hasattr(sock, "_sock"): + sock = sock._sock + + ssl_sock = _ssl.sslwrap(sock, 0, keyfile, certfile, CERT_NONE, + PROTOCOL_SSLv23, None) + try: + sock.getpeername() + except socket_error: + # no, no connection yet + pass + else: + # yes, do the handshake + ssl_sock.do_handshake() + + return ssl_sock diff --git a/stat.py b/stat.py new file mode 100644 index 0000000..abed5c9 --- /dev/null +++ b/stat.py @@ -0,0 +1,96 @@ +"""Constants/functions for interpreting results of os.stat() and os.lstat(). + +Suggested usage: from stat import * +""" + +# Indices for stat struct members in the tuple returned by os.stat() + +ST_MODE = 0 +ST_INO = 1 +ST_DEV = 2 +ST_NLINK = 3 +ST_UID = 4 +ST_GID = 5 +ST_SIZE = 6 +ST_ATIME = 7 +ST_MTIME = 8 +ST_CTIME = 9 + +# Extract bits from the mode + +def S_IMODE(mode): + return mode & 07777 + +def S_IFMT(mode): + return mode & 0170000 + +# Constants used as S_IFMT() for various file types +# (not all are implemented on all systems) + +S_IFDIR = 0040000 +S_IFCHR = 0020000 +S_IFBLK = 0060000 +S_IFREG = 0100000 +S_IFIFO = 0010000 +S_IFLNK = 0120000 +S_IFSOCK = 0140000 + +# Functions to test for each file type + +def S_ISDIR(mode): + return S_IFMT(mode) == S_IFDIR + +def S_ISCHR(mode): + return S_IFMT(mode) == S_IFCHR + +def S_ISBLK(mode): + return S_IFMT(mode) == S_IFBLK + +def S_ISREG(mode): + return S_IFMT(mode) == S_IFREG + +def S_ISFIFO(mode): + return S_IFMT(mode) == S_IFIFO + +def S_ISLNK(mode): + return S_IFMT(mode) == S_IFLNK + +def S_ISSOCK(mode): + return S_IFMT(mode) == S_IFSOCK + +# Names for permission bits + +S_ISUID = 04000 +S_ISGID = 02000 +S_ENFMT = S_ISGID +S_ISVTX = 01000 +S_IREAD = 00400 +S_IWRITE = 00200 +S_IEXEC = 00100 +S_IRWXU = 00700 +S_IRUSR = 00400 +S_IWUSR = 00200 +S_IXUSR = 00100 +S_IRWXG = 00070 +S_IRGRP = 00040 +S_IWGRP = 00020 +S_IXGRP = 00010 +S_IRWXO = 00007 +S_IROTH = 00004 +S_IWOTH = 00002 +S_IXOTH = 00001 + +# Names for file flags + +UF_NODUMP = 0x00000001 +UF_IMMUTABLE = 0x00000002 +UF_APPEND = 0x00000004 +UF_OPAQUE = 0x00000008 +UF_NOUNLINK = 0x00000010 +UF_COMPRESSED = 0x00000020 # OS X: file is hfs-compressed +UF_HIDDEN = 0x00008000 # OS X: file should not be displayed +SF_ARCHIVED = 0x00010000 +SF_IMMUTABLE = 0x00020000 +SF_APPEND = 0x00040000 +SF_NOUNLINK = 0x00100000 +SF_SNAPSHOT = 0x00200000 diff --git a/string.py b/string.py new file mode 100644 index 0000000..23608b4 --- /dev/null +++ b/string.py @@ -0,0 +1,656 @@ +"""A collection of string operations (most are no longer used). + +Warning: most of the code you see here isn't normally used nowadays. +Beginning with Python 1.6, many of these functions are implemented as +methods on the standard string object. They used to be implemented by +a built-in module called strop, but strop is now obsolete itself. + +Public module variables: + +whitespace -- a string containing all characters considered whitespace +lowercase -- a string containing all characters considered lowercase letters +uppercase -- a string containing all characters considered uppercase letters +letters -- a string containing all characters considered letters +digits -- a string containing all characters considered decimal digits +hexdigits -- a string containing all characters considered hexadecimal digits +octdigits -- a string containing all characters considered octal digits +punctuation -- a string containing all characters considered punctuation +printable -- a string containing all characters considered printable + +""" + +# Some strings for ctype-style character classification +whitespace = ' \t\n\r\v\f' +lowercase = 'abcdefghijklmnopqrstuvwxyz' +uppercase = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' +letters = lowercase + uppercase +ascii_lowercase = lowercase +ascii_uppercase = uppercase +ascii_letters = ascii_lowercase + ascii_uppercase +digits = '0123456789' +hexdigits = digits + 'abcdef' + 'ABCDEF' +octdigits = '01234567' +punctuation = """!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~""" +printable = digits + letters + punctuation + whitespace + +# Case conversion helpers +# Use str to convert Unicode literal in case of -U +l = map(chr, xrange(256)) +_idmap = str('').join(l) +del l + +# Functions which aren't available as string methods. + +# Capitalize the words in a string, e.g. " aBc dEf " -> "Abc Def". +def capwords(s, sep=None): + """capwords(s [,sep]) -> string + + Split the argument into words using split, capitalize each + word using capitalize, and join the capitalized words using + join. If the optional second argument sep is absent or None, + runs of whitespace characters are replaced by a single space + and leading and trailing whitespace are removed, otherwise + sep is used to split and join the words. + + """ + return (sep or ' ').join(x.capitalize() for x in s.split(sep)) + + +# Construct a translation string +_idmapL = None +def maketrans(fromstr, tostr): + """maketrans(frm, to) -> string + + Return a translation table (a string of 256 bytes long) + suitable for use in string.translate. The strings frm and to + must be of the same length. + + """ + if len(fromstr) != len(tostr): + raise ValueError, "maketrans arguments must have same length" + global _idmapL + if not _idmapL: + _idmapL = list(_idmap) + L = _idmapL[:] + fromstr = map(ord, fromstr) + for i in range(len(fromstr)): + L[fromstr[i]] = tostr[i] + return ''.join(L) + + + +#################################################################### +import re as _re + +class _multimap: + """Helper class for combining multiple mappings. + + Used by .{safe_,}substitute() to combine the mapping and keyword + arguments. + """ + def __init__(self, primary, secondary): + self._primary = primary + self._secondary = secondary + + def __getitem__(self, key): + try: + return self._primary[key] + except KeyError: + return self._secondary[key] + + +class _TemplateMetaclass(type): + pattern = r""" + %(delim)s(?: + (?P%(delim)s) | # Escape sequence of two delimiters + (?P%(id)s) | # delimiter and a Python identifier + {(?P%(id)s)} | # delimiter and a braced identifier + (?P) # Other ill-formed delimiter exprs + ) + """ + + def __init__(cls, name, bases, dct): + super(_TemplateMetaclass, cls).__init__(name, bases, dct) + if 'pattern' in dct: + pattern = cls.pattern + else: + pattern = _TemplateMetaclass.pattern % { + 'delim' : _re.escape(cls.delimiter), + 'id' : cls.idpattern, + } + cls.pattern = _re.compile(pattern, _re.IGNORECASE | _re.VERBOSE) + + +class Template: + """A string class for supporting $-substitutions.""" + __metaclass__ = _TemplateMetaclass + + delimiter = '$' + idpattern = r'[_a-z][_a-z0-9]*' + + def __init__(self, template): + self.template = template + + # Search for $$, $identifier, ${identifier}, and any bare $'s + + def _invalid(self, mo): + i = mo.start('invalid') + lines = self.template[:i].splitlines(True) + if not lines: + colno = 1 + lineno = 1 + else: + colno = i - len(''.join(lines[:-1])) + lineno = len(lines) + raise ValueError('Invalid placeholder in string: line %d, col %d' % + (lineno, colno)) + + def substitute(*args, **kws): + if not args: + raise TypeError("descriptor 'substitute' of 'Template' object " + "needs an argument") + self, args = args[0], args[1:] # allow the "self" keyword be passed + if len(args) > 1: + raise TypeError('Too many positional arguments') + if not args: + mapping = kws + elif kws: + mapping = _multimap(kws, args[0]) + else: + mapping = args[0] + # Helper function for .sub() + def convert(mo): + # Check the most common path first. + named = mo.group('named') or mo.group('braced') + if named is not None: + val = mapping[named] + # We use this idiom instead of str() because the latter will + # fail if val is a Unicode containing non-ASCII characters. + return '%s' % (val,) + if mo.group('escaped') is not None: + return self.delimiter + if mo.group('invalid') is not None: + self._invalid(mo) + raise ValueError('Unrecognized named group in pattern', + self.pattern) + return self.pattern.sub(convert, self.template) + + def safe_substitute(*args, **kws): + if not args: + raise TypeError("descriptor 'safe_substitute' of 'Template' object " + "needs an argument") + self, args = args[0], args[1:] # allow the "self" keyword be passed + if len(args) > 1: + raise TypeError('Too many positional arguments') + if not args: + mapping = kws + elif kws: + mapping = _multimap(kws, args[0]) + else: + mapping = args[0] + # Helper function for .sub() + def convert(mo): + named = mo.group('named') or mo.group('braced') + if named is not None: + try: + # We use this idiom instead of str() because the latter + # will fail if val is a Unicode containing non-ASCII + return '%s' % (mapping[named],) + except KeyError: + return mo.group() + if mo.group('escaped') is not None: + return self.delimiter + if mo.group('invalid') is not None: + return mo.group() + raise ValueError('Unrecognized named group in pattern', + self.pattern) + return self.pattern.sub(convert, self.template) + + + +#################################################################### +# NOTE: Everything below here is deprecated. Use string methods instead. +# This stuff will go away in Python 3.0. + +# Backward compatible names for exceptions +index_error = ValueError +atoi_error = ValueError +atof_error = ValueError +atol_error = ValueError + +# convert UPPER CASE letters to lower case +def lower(s): + """lower(s) -> string + + Return a copy of the string s converted to lowercase. + + """ + return s.lower() + +# Convert lower case letters to UPPER CASE +def upper(s): + """upper(s) -> string + + Return a copy of the string s converted to uppercase. + + """ + return s.upper() + +# Swap lower case letters and UPPER CASE +def swapcase(s): + """swapcase(s) -> string + + Return a copy of the string s with upper case characters + converted to lowercase and vice versa. + + """ + return s.swapcase() + +# Strip leading and trailing tabs and spaces +def strip(s, chars=None): + """strip(s [,chars]) -> string + + Return a copy of the string s with leading and trailing + whitespace removed. + If chars is given and not None, remove characters in chars instead. + If chars is unicode, S will be converted to unicode before stripping. + + """ + return s.strip(chars) + +# Strip leading tabs and spaces +def lstrip(s, chars=None): + """lstrip(s [,chars]) -> string + + Return a copy of the string s with leading whitespace removed. + If chars is given and not None, remove characters in chars instead. + + """ + return s.lstrip(chars) + +# Strip trailing tabs and spaces +def rstrip(s, chars=None): + """rstrip(s [,chars]) -> string + + Return a copy of the string s with trailing whitespace removed. + If chars is given and not None, remove characters in chars instead. + + """ + return s.rstrip(chars) + + +# Split a string into a list of space/tab-separated words +def split(s, sep=None, maxsplit=-1): + """split(s [,sep [,maxsplit]]) -> list of strings + + Return a list of the words in the string s, using sep as the + delimiter string. If maxsplit is given, splits at no more than + maxsplit places (resulting in at most maxsplit+1 words). If sep + is not specified or is None, any whitespace string is a separator. + + (split and splitfields are synonymous) + + """ + return s.split(sep, maxsplit) +splitfields = split + +# Split a string into a list of space/tab-separated words +def rsplit(s, sep=None, maxsplit=-1): + """rsplit(s [,sep [,maxsplit]]) -> list of strings + + Return a list of the words in the string s, using sep as the + delimiter string, starting at the end of the string and working + to the front. If maxsplit is given, at most maxsplit splits are + done. If sep is not specified or is None, any whitespace string + is a separator. + """ + return s.rsplit(sep, maxsplit) + +# Join fields with optional separator +def join(words, sep = ' '): + """join(list [,sep]) -> string + + Return a string composed of the words in list, with + intervening occurrences of sep. The default separator is a + single space. + + (joinfields and join are synonymous) + + """ + return sep.join(words) +joinfields = join + +# Find substring, raise exception if not found +def index(s, *args): + """index(s, sub [,start [,end]]) -> int + + Like find but raises ValueError when the substring is not found. + + """ + return s.index(*args) + +# Find last substring, raise exception if not found +def rindex(s, *args): + """rindex(s, sub [,start [,end]]) -> int + + Like rfind but raises ValueError when the substring is not found. + + """ + return s.rindex(*args) + +# Count non-overlapping occurrences of substring +def count(s, *args): + """count(s, sub[, start[,end]]) -> int + + Return the number of occurrences of substring sub in string + s[start:end]. Optional arguments start and end are + interpreted as in slice notation. + + """ + return s.count(*args) + +# Find substring, return -1 if not found +def find(s, *args): + """find(s, sub [,start [,end]]) -> in + + Return the lowest index in s where substring sub is found, + such that sub is contained within s[start,end]. Optional + arguments start and end are interpreted as in slice notation. + + Return -1 on failure. + + """ + return s.find(*args) + +# Find last substring, return -1 if not found +def rfind(s, *args): + """rfind(s, sub [,start [,end]]) -> int + + Return the highest index in s where substring sub is found, + such that sub is contained within s[start,end]. Optional + arguments start and end are interpreted as in slice notation. + + Return -1 on failure. + + """ + return s.rfind(*args) + +# for a bit of speed +_float = float +_int = int +_long = long + +# Convert string to float +def atof(s): + """atof(s) -> float + + Return the floating point number represented by the string s. + + """ + return _float(s) + + +# Convert string to integer +def atoi(s , base=10): + """atoi(s [,base]) -> int + + Return the integer represented by the string s in the given + base, which defaults to 10. The string s must consist of one + or more digits, possibly preceded by a sign. If base is 0, it + is chosen from the leading characters of s, 0 for octal, 0x or + 0X for hexadecimal. If base is 16, a preceding 0x or 0X is + accepted. + + """ + return _int(s, base) + + +# Convert string to long integer +def atol(s, base=10): + """atol(s [,base]) -> long + + Return the long integer represented by the string s in the + given base, which defaults to 10. The string s must consist + of one or more digits, possibly preceded by a sign. If base + is 0, it is chosen from the leading characters of s, 0 for + octal, 0x or 0X for hexadecimal. If base is 16, a preceding + 0x or 0X is accepted. A trailing L or l is not accepted, + unless base is 0. + + """ + return _long(s, base) + + +# Left-justify a string +def ljust(s, width, *args): + """ljust(s, width[, fillchar]) -> string + + Return a left-justified version of s, in a field of the + specified width, padded with spaces as needed. The string is + never truncated. If specified the fillchar is used instead of spaces. + + """ + return s.ljust(width, *args) + +# Right-justify a string +def rjust(s, width, *args): + """rjust(s, width[, fillchar]) -> string + + Return a right-justified version of s, in a field of the + specified width, padded with spaces as needed. The string is + never truncated. If specified the fillchar is used instead of spaces. + + """ + return s.rjust(width, *args) + +# Center a string +def center(s, width, *args): + """center(s, width[, fillchar]) -> string + + Return a center version of s, in a field of the specified + width. padded with spaces as needed. The string is never + truncated. If specified the fillchar is used instead of spaces. + + """ + return s.center(width, *args) + +# Zero-fill a number, e.g., (12, 3) --> '012' and (-3, 3) --> '-03' +# Decadent feature: the argument may be a string or a number +# (Use of this is deprecated; it should be a string as with ljust c.s.) +def zfill(x, width): + """zfill(x, width) -> string + + Pad a numeric string x with zeros on the left, to fill a field + of the specified width. The string x is never truncated. + + """ + if not isinstance(x, basestring): + x = repr(x) + return x.zfill(width) + +# Expand tabs in a string. +# Doesn't take non-printing chars into account, but does understand \n. +def expandtabs(s, tabsize=8): + """expandtabs(s [,tabsize]) -> string + + Return a copy of the string s with all tab characters replaced + by the appropriate number of spaces, depending on the current + column, and the tabsize (default 8). + + """ + return s.expandtabs(tabsize) + +# Character translation through look-up table. +def translate(s, table, deletions=""): + """translate(s,table [,deletions]) -> string + + Return a copy of the string s, where all characters occurring + in the optional argument deletions are removed, and the + remaining characters have been mapped through the given + translation table, which must be a string of length 256. The + deletions argument is not allowed for Unicode strings. + + """ + if deletions or table is None: + return s.translate(table, deletions) + else: + # Add s[:0] so that if s is Unicode and table is an 8-bit string, + # table is converted to Unicode. This means that table *cannot* + # be a dictionary -- for that feature, use u.translate() directly. + return s.translate(table + s[:0]) + +# Capitalize a string, e.g. "aBc dEf" -> "Abc def". +def capitalize(s): + """capitalize(s) -> string + + Return a copy of the string s with only its first character + capitalized. + + """ + return s.capitalize() + +# Substring replacement (global) +def replace(s, old, new, maxreplace=-1): + """replace (str, old, new[, maxreplace]) -> string + + Return a copy of string str with all occurrences of substring + old replaced by new. If the optional argument maxreplace is + given, only the first maxreplace occurrences are replaced. + + """ + return s.replace(old, new, maxreplace) + + +# Try importing optional built-in module "strop" -- if it exists, +# it redefines some string operations that are 100-1000 times faster. +# It also defines values for whitespace, lowercase and uppercase +# that match 's definitions. + +try: + from strop import maketrans, lowercase, uppercase, whitespace + letters = lowercase + uppercase +except ImportError: + pass # Use the original versions + +######################################################################## +# the Formatter class +# see PEP 3101 for details and purpose of this class + +# The hard parts are reused from the C implementation. They're exposed as "_" +# prefixed methods of str and unicode. + +# The overall parser is implemented in str._formatter_parser. +# The field name parser is implemented in str._formatter_field_name_split + +class Formatter(object): + def format(*args, **kwargs): + if not args: + raise TypeError("descriptor 'format' of 'Formatter' object " + "needs an argument") + self, args = args[0], args[1:] # allow the "self" keyword be passed + try: + format_string, args = args[0], args[1:] # allow the "format_string" keyword be passed + except IndexError: + if 'format_string' in kwargs: + format_string = kwargs.pop('format_string') + else: + raise TypeError("format() missing 1 required positional " + "argument: 'format_string'") + return self.vformat(format_string, args, kwargs) + + def vformat(self, format_string, args, kwargs): + used_args = set() + result = self._vformat(format_string, args, kwargs, used_args, 2) + self.check_unused_args(used_args, args, kwargs) + return result + + def _vformat(self, format_string, args, kwargs, used_args, recursion_depth): + if recursion_depth < 0: + raise ValueError('Max string recursion exceeded') + result = [] + for literal_text, field_name, format_spec, conversion in \ + self.parse(format_string): + + # output the literal text + if literal_text: + result.append(literal_text) + + # if there's a field, output it + if field_name is not None: + # this is some markup, find the object and do + # the formatting + + # given the field_name, find the object it references + # and the argument it came from + obj, arg_used = self.get_field(field_name, args, kwargs) + used_args.add(arg_used) + + # do any conversion on the resulting object + obj = self.convert_field(obj, conversion) + + # expand the format spec, if needed + format_spec = self._vformat(format_spec, args, kwargs, + used_args, recursion_depth-1) + + # format the object and append to the result + result.append(self.format_field(obj, format_spec)) + + return ''.join(result) + + + def get_value(self, key, args, kwargs): + if isinstance(key, (int, long)): + return args[key] + else: + return kwargs[key] + + + def check_unused_args(self, used_args, args, kwargs): + pass + + + def format_field(self, value, format_spec): + return format(value, format_spec) + + + def convert_field(self, value, conversion): + # do any conversion on the resulting object + if conversion is None: + return value + elif conversion == 's': + return str(value) + elif conversion == 'r': + return repr(value) + raise ValueError("Unknown conversion specifier {0!s}".format(conversion)) + + + # returns an iterable that contains tuples of the form: + # (literal_text, field_name, format_spec, conversion) + # literal_text can be zero length + # field_name can be None, in which case there's no + # object to format and output + # if field_name is not None, it is looked up, formatted + # with format_spec and conversion and then used + def parse(self, format_string): + return format_string._formatter_parser() + + + # given a field_name, find the object it references. + # field_name: the field being looked up, e.g. "0.name" + # or "lookup[3]" + # used_args: a set of which args have been used + # args, kwargs: as passed in to vformat + def get_field(self, field_name, args, kwargs): + first, rest = field_name._formatter_field_name_split() + + obj = self.get_value(first, args, kwargs) + + # loop through the rest of the field_name, doing + # getattr or getitem as needed + for is_attr, i in rest: + if is_attr: + obj = getattr(obj, i) + else: + obj = obj[i] + + return obj, first diff --git a/struct.py b/struct.py new file mode 100644 index 0000000..b022355 --- /dev/null +++ b/struct.py @@ -0,0 +1,3 @@ +from _struct import * +from _struct import _clearcache +from _struct import __doc__ diff --git a/tempfile.py b/tempfile.py new file mode 100644 index 0000000..184dfc1 --- /dev/null +++ b/tempfile.py @@ -0,0 +1,639 @@ +"""Temporary files. + +This module provides generic, low- and high-level interfaces for +creating temporary files and directories. All of the interfaces +provided by this module can be used without fear of race conditions +except for 'mktemp'. 'mktemp' is subject to race conditions and +should not be used; it is provided for backward compatibility only. + +This module also provides some data items to the user: + + TMP_MAX - maximum number of names that will be tried before + giving up. + template - the default prefix for all temporary names. + You may change this to control the default prefix. + tempdir - If this is set to a string before the first use of + any routine from this module, it will be considered as + another candidate location to store temporary files. +""" + +__all__ = [ + "NamedTemporaryFile", "TemporaryFile", # high level safe interfaces + "SpooledTemporaryFile", + "mkstemp", "mkdtemp", # low level safe interfaces + "mktemp", # deprecated unsafe interface + "TMP_MAX", "gettempprefix", # constants + "tempdir", "gettempdir" + ] + + +# Imports. + +import io as _io +import os as _os +import errno as _errno +from random import Random as _Random + +try: + from cStringIO import StringIO as _StringIO +except ImportError: + from StringIO import StringIO as _StringIO + +try: + import fcntl as _fcntl +except ImportError: + def _set_cloexec(fd): + pass +else: + def _set_cloexec(fd): + try: + flags = _fcntl.fcntl(fd, _fcntl.F_GETFD, 0) + except IOError: + pass + else: + # flags read successfully, modify + flags |= _fcntl.FD_CLOEXEC + _fcntl.fcntl(fd, _fcntl.F_SETFD, flags) + + +try: + import thread as _thread +except ImportError: + import dummy_thread as _thread +_allocate_lock = _thread.allocate_lock + +_text_openflags = _os.O_RDWR | _os.O_CREAT | _os.O_EXCL +if hasattr(_os, 'O_NOINHERIT'): + _text_openflags |= _os.O_NOINHERIT +if hasattr(_os, 'O_NOFOLLOW'): + _text_openflags |= _os.O_NOFOLLOW + +_bin_openflags = _text_openflags +if hasattr(_os, 'O_BINARY'): + _bin_openflags |= _os.O_BINARY + +if hasattr(_os, 'TMP_MAX'): + TMP_MAX = _os.TMP_MAX +else: + TMP_MAX = 10000 + +template = "tmp" + +# Internal routines. + +_once_lock = _allocate_lock() + +if hasattr(_os, "lstat"): + _stat = _os.lstat +elif hasattr(_os, "stat"): + _stat = _os.stat +else: + # Fallback. All we need is something that raises os.error if the + # file doesn't exist. + def _stat(fn): + try: + f = open(fn) + except IOError: + raise _os.error + f.close() + +def _exists(fn): + try: + _stat(fn) + except _os.error: + return False + else: + return True + +class _RandomNameSequence: + """An instance of _RandomNameSequence generates an endless + sequence of unpredictable strings which can safely be incorporated + into file names. Each string is six characters long. Multiple + threads can safely use the same instance at the same time. + + _RandomNameSequence is an iterator.""" + + characters = ("abcdefghijklmnopqrstuvwxyz" + + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + + "0123456789_") + + def __init__(self): + self.mutex = _allocate_lock() + self.normcase = _os.path.normcase + + @property + def rng(self): + cur_pid = _os.getpid() + if cur_pid != getattr(self, '_rng_pid', None): + self._rng = _Random() + self._rng_pid = cur_pid + return self._rng + + def __iter__(self): + return self + + def next(self): + m = self.mutex + c = self.characters + choose = self.rng.choice + + m.acquire() + try: + letters = [choose(c) for dummy in "123456"] + finally: + m.release() + + return self.normcase(''.join(letters)) + +def _candidate_tempdir_list(): + """Generate a list of candidate temporary directories which + _get_default_tempdir will try.""" + + dirlist = [] + + # First, try the environment. + for envname in 'TMPDIR', 'TEMP', 'TMP': + dirname = _os.getenv(envname) + if dirname: dirlist.append(dirname) + + # Failing that, try OS-specific locations. + if _os.name == 'riscos': + dirname = _os.getenv('Wimp$ScrapDir') + if dirname: dirlist.append(dirname) + elif _os.name == 'nt': + dirlist.extend([ r'c:\temp', r'c:\tmp', r'\temp', r'\tmp' ]) + else: + dirlist.extend([ '/tmp', '/var/tmp', '/usr/tmp' ]) + + # As a last resort, the current directory. + try: + dirlist.append(_os.getcwd()) + except (AttributeError, _os.error): + dirlist.append(_os.curdir) + + return dirlist + +def _get_default_tempdir(): + """Calculate the default directory to use for temporary files. + This routine should be called exactly once. + + We determine whether or not a candidate temp dir is usable by + trying to create and write to a file in that directory. If this + is successful, the test file is deleted. To prevent denial of + service, the name of the test file must be randomized.""" + + namer = _RandomNameSequence() + dirlist = _candidate_tempdir_list() + flags = _text_openflags + + for dir in dirlist: + if dir != _os.curdir: + dir = _os.path.normcase(_os.path.abspath(dir)) + # Try only a few names per directory. + for seq in xrange(100): + name = namer.next() + filename = _os.path.join(dir, name) + try: + fd = _os.open(filename, flags, 0o600) + try: + try: + with _io.open(fd, 'wb', closefd=False) as fp: + fp.write(b'blat') + finally: + _os.close(fd) + finally: + _os.unlink(filename) + return dir + except (OSError, IOError) as e: + if e.args[0] == _errno.EEXIST: + continue + if (_os.name == 'nt' and e.args[0] == _errno.EACCES and + _os.path.isdir(dir) and _os.access(dir, _os.W_OK)): + # On windows, when a directory with the chosen name already + # exists, EACCES error code is returned instead of EEXIST. + continue + break # no point trying more names in this directory + raise IOError, (_errno.ENOENT, + ("No usable temporary directory found in %s" % dirlist)) + +_name_sequence = None + +def _get_candidate_names(): + """Common setup sequence for all user-callable interfaces.""" + + global _name_sequence + if _name_sequence is None: + _once_lock.acquire() + try: + if _name_sequence is None: + _name_sequence = _RandomNameSequence() + finally: + _once_lock.release() + return _name_sequence + + +def _mkstemp_inner(dir, pre, suf, flags): + """Code common to mkstemp, TemporaryFile, and NamedTemporaryFile.""" + + names = _get_candidate_names() + + for seq in xrange(TMP_MAX): + name = names.next() + file = _os.path.join(dir, pre + name + suf) + try: + fd = _os.open(file, flags, 0600) + _set_cloexec(fd) + return (fd, _os.path.abspath(file)) + except OSError, e: + if e.errno == _errno.EEXIST: + continue # try again + if (_os.name == 'nt' and e.errno == _errno.EACCES and + _os.path.isdir(dir) and _os.access(dir, _os.W_OK)): + # On windows, when a directory with the chosen name already + # exists, EACCES error code is returned instead of EEXIST. + continue + raise + + raise IOError, (_errno.EEXIST, "No usable temporary file name found") + + +# User visible interfaces. + +def gettempprefix(): + """Accessor for tempdir.template.""" + return template + +tempdir = None + +def gettempdir(): + """Accessor for tempfile.tempdir.""" + global tempdir + if tempdir is None: + _once_lock.acquire() + try: + if tempdir is None: + tempdir = _get_default_tempdir() + finally: + _once_lock.release() + return tempdir + +def mkstemp(suffix="", prefix=template, dir=None, text=False): + """User-callable function to create and return a unique temporary + file. The return value is a pair (fd, name) where fd is the + file descriptor returned by os.open, and name is the filename. + + If 'suffix' is specified, the file name will end with that suffix, + otherwise there will be no suffix. + + If 'prefix' is specified, the file name will begin with that prefix, + otherwise a default prefix is used. + + If 'dir' is specified, the file will be created in that directory, + otherwise a default directory is used. + + If 'text' is specified and true, the file is opened in text + mode. Else (the default) the file is opened in binary mode. On + some operating systems, this makes no difference. + + The file is readable and writable only by the creating user ID. + If the operating system uses permission bits to indicate whether a + file is executable, the file is executable by no one. The file + descriptor is not inherited by children of this process. + + Caller is responsible for deleting the file when done with it. + """ + + if dir is None: + dir = gettempdir() + + if text: + flags = _text_openflags + else: + flags = _bin_openflags + + return _mkstemp_inner(dir, prefix, suffix, flags) + + +def mkdtemp(suffix="", prefix=template, dir=None): + """User-callable function to create and return a unique temporary + directory. The return value is the pathname of the directory. + + Arguments are as for mkstemp, except that the 'text' argument is + not accepted. + + The directory is readable, writable, and searchable only by the + creating user. + + Caller is responsible for deleting the directory when done with it. + """ + + if dir is None: + dir = gettempdir() + + names = _get_candidate_names() + + for seq in xrange(TMP_MAX): + name = names.next() + file = _os.path.join(dir, prefix + name + suffix) + try: + _os.mkdir(file, 0700) + return file + except OSError, e: + if e.errno == _errno.EEXIST: + continue # try again + if (_os.name == 'nt' and e.errno == _errno.EACCES and + _os.path.isdir(dir) and _os.access(dir, _os.W_OK)): + # On windows, when a directory with the chosen name already + # exists, EACCES error code is returned instead of EEXIST. + continue + raise + + raise IOError, (_errno.EEXIST, "No usable temporary directory name found") + +def mktemp(suffix="", prefix=template, dir=None): + """User-callable function to return a unique temporary file name. The + file is not created. + + Arguments are as for mkstemp, except that the 'text' argument is + not accepted. + + This function is unsafe and should not be used. The file name + refers to a file that did not exist at some point, but by the time + you get around to creating it, someone else may have beaten you to + the punch. + """ + +## from warnings import warn as _warn +## _warn("mktemp is a potential security risk to your program", +## RuntimeWarning, stacklevel=2) + + if dir is None: + dir = gettempdir() + + names = _get_candidate_names() + for seq in xrange(TMP_MAX): + name = names.next() + file = _os.path.join(dir, prefix + name + suffix) + if not _exists(file): + return file + + raise IOError, (_errno.EEXIST, "No usable temporary filename found") + + +class _TemporaryFileWrapper: + """Temporary file wrapper + + This class provides a wrapper around files opened for + temporary use. In particular, it seeks to automatically + remove the file when it is no longer needed. + """ + + def __init__(self, file, name, delete=True): + self.file = file + self.name = name + self.close_called = False + self.delete = delete + + def __getattr__(self, name): + # Attribute lookups are delegated to the underlying file + # and cached for non-numeric results + # (i.e. methods are cached, closed and friends are not) + file = self.__dict__['file'] + a = getattr(file, name) + if not issubclass(type(a), type(0)): + setattr(self, name, a) + return a + + # The underlying __enter__ method returns the wrong object + # (self.file) so override it to return the wrapper + def __enter__(self): + self.file.__enter__() + return self + + # NT provides delete-on-close as a primitive, so we don't need + # the wrapper to do anything special. We still use it so that + # file.name is useful (i.e. not "(fdopen)") with NamedTemporaryFile. + if _os.name != 'nt': + # Cache the unlinker so we don't get spurious errors at + # shutdown when the module-level "os" is None'd out. Note + # that this must be referenced as self.unlink, because the + # name TemporaryFileWrapper may also get None'd out before + # __del__ is called. + unlink = _os.unlink + + def close(self): + if not self.close_called: + self.close_called = True + try: + self.file.close() + finally: + if self.delete: + self.unlink(self.name) + + def __del__(self): + self.close() + + # Need to trap __exit__ as well to ensure the file gets + # deleted when used in a with statement + def __exit__(self, exc, value, tb): + result = self.file.__exit__(exc, value, tb) + self.close() + return result + else: + def __exit__(self, exc, value, tb): + self.file.__exit__(exc, value, tb) + + +def NamedTemporaryFile(mode='w+b', bufsize=-1, suffix="", + prefix=template, dir=None, delete=True): + """Create and return a temporary file. + Arguments: + 'prefix', 'suffix', 'dir' -- as for mkstemp. + 'mode' -- the mode argument to os.fdopen (default "w+b"). + 'bufsize' -- the buffer size argument to os.fdopen (default -1). + 'delete' -- whether the file is deleted on close (default True). + The file is created as mkstemp() would do it. + + Returns an object with a file-like interface; the name of the file + is accessible as file.name. The file will be automatically deleted + when it is closed unless the 'delete' argument is set to False. + """ + + if dir is None: + dir = gettempdir() + + if 'b' in mode: + flags = _bin_openflags + else: + flags = _text_openflags + + # Setting O_TEMPORARY in the flags causes the OS to delete + # the file when it is closed. This is only supported by Windows. + if _os.name == 'nt' and delete: + flags |= _os.O_TEMPORARY + + (fd, name) = _mkstemp_inner(dir, prefix, suffix, flags) + try: + file = _os.fdopen(fd, mode, bufsize) + return _TemporaryFileWrapper(file, name, delete) + except: + _os.close(fd) + raise + +if _os.name != 'posix' or _os.sys.platform == 'cygwin': + # On non-POSIX and Cygwin systems, assume that we cannot unlink a file + # while it is open. + TemporaryFile = NamedTemporaryFile + +else: + def TemporaryFile(mode='w+b', bufsize=-1, suffix="", + prefix=template, dir=None): + """Create and return a temporary file. + Arguments: + 'prefix', 'suffix', 'dir' -- as for mkstemp. + 'mode' -- the mode argument to os.fdopen (default "w+b"). + 'bufsize' -- the buffer size argument to os.fdopen (default -1). + The file is created as mkstemp() would do it. + + Returns an object with a file-like interface. The file has no + name, and will cease to exist when it is closed. + """ + + if dir is None: + dir = gettempdir() + + if 'b' in mode: + flags = _bin_openflags + else: + flags = _text_openflags + + (fd, name) = _mkstemp_inner(dir, prefix, suffix, flags) + try: + _os.unlink(name) + return _os.fdopen(fd, mode, bufsize) + except: + _os.close(fd) + raise + +class SpooledTemporaryFile: + """Temporary file wrapper, specialized to switch from + StringIO to a real file when it exceeds a certain size or + when a fileno is needed. + """ + _rolled = False + + def __init__(self, max_size=0, mode='w+b', bufsize=-1, + suffix="", prefix=template, dir=None): + self._file = _StringIO() + self._max_size = max_size + self._rolled = False + self._TemporaryFileArgs = (mode, bufsize, suffix, prefix, dir) + + def _check(self, file): + if self._rolled: return + max_size = self._max_size + if max_size and file.tell() > max_size: + self.rollover() + + def rollover(self): + if self._rolled: return + file = self._file + newfile = self._file = TemporaryFile(*self._TemporaryFileArgs) + del self._TemporaryFileArgs + + newfile.write(file.getvalue()) + newfile.seek(file.tell(), 0) + + self._rolled = True + + # The method caching trick from NamedTemporaryFile + # won't work here, because _file may change from a + # _StringIO instance to a real file. So we list + # all the methods directly. + + # Context management protocol + def __enter__(self): + if self._file.closed: + raise ValueError("Cannot enter context with closed file") + return self + + def __exit__(self, exc, value, tb): + self._file.close() + + # file protocol + def __iter__(self): + return self._file.__iter__() + + def close(self): + self._file.close() + + @property + def closed(self): + return self._file.closed + + def fileno(self): + self.rollover() + return self._file.fileno() + + def flush(self): + self._file.flush() + + def isatty(self): + return self._file.isatty() + + @property + def mode(self): + try: + return self._file.mode + except AttributeError: + return self._TemporaryFileArgs[0] + + @property + def name(self): + try: + return self._file.name + except AttributeError: + return None + + def next(self): + return self._file.next + + def read(self, *args): + return self._file.read(*args) + + def readline(self, *args): + return self._file.readline(*args) + + def readlines(self, *args): + return self._file.readlines(*args) + + def seek(self, *args): + self._file.seek(*args) + + @property + def softspace(self): + return self._file.softspace + + def tell(self): + return self._file.tell() + + def truncate(self): + self._file.truncate() + + def write(self, s): + file = self._file + rv = file.write(s) + self._check(file) + return rv + + def writelines(self, iterable): + file = self._file + rv = file.writelines(iterable) + self._check(file) + return rv + + def xreadlines(self, *args): + if hasattr(self._file, 'xreadlines'): # real file + return iter(self._file) + else: # StringIO() + return iter(self._file.readlines(*args)) diff --git a/textwrap.py b/textwrap.py new file mode 100644 index 0000000..5c2e4fa --- /dev/null +++ b/textwrap.py @@ -0,0 +1,429 @@ +"""Text wrapping and filling. +""" + +# Copyright (C) 1999-2001 Gregory P. Ward. +# Copyright (C) 2002, 2003 Python Software Foundation. +# Written by Greg Ward + +__revision__ = "$Id$" + +import string, re + +try: + _unicode = unicode +except NameError: + # If Python is built without Unicode support, the unicode type + # will not exist. Fake one. + class _unicode(object): + pass + +# Do the right thing with boolean values for all known Python versions +# (so this module can be copied to projects that don't depend on Python +# 2.3, e.g. Optik and Docutils) by uncommenting the block of code below. +#try: +# True, False +#except NameError: +# (True, False) = (1, 0) + +__all__ = ['TextWrapper', 'wrap', 'fill', 'dedent'] + +# Hardcode the recognized whitespace characters to the US-ASCII +# whitespace characters. The main reason for doing this is that in +# ISO-8859-1, 0xa0 is non-breaking whitespace, so in certain locales +# that character winds up in string.whitespace. Respecting +# string.whitespace in those cases would 1) make textwrap treat 0xa0 the +# same as any other whitespace char, which is clearly wrong (it's a +# *non-breaking* space), 2) possibly cause problems with Unicode, +# since 0xa0 is not in range(128). +_whitespace = '\t\n\x0b\x0c\r ' + +class TextWrapper: + """ + Object for wrapping/filling text. The public interface consists of + the wrap() and fill() methods; the other methods are just there for + subclasses to override in order to tweak the default behaviour. + If you want to completely replace the main wrapping algorithm, + you'll probably have to override _wrap_chunks(). + + Several instance attributes control various aspects of wrapping: + width (default: 70) + the maximum width of wrapped lines (unless break_long_words + is false) + initial_indent (default: "") + string that will be prepended to the first line of wrapped + output. Counts towards the line's width. + subsequent_indent (default: "") + string that will be prepended to all lines save the first + of wrapped output; also counts towards each line's width. + expand_tabs (default: true) + Expand tabs in input text to spaces before further processing. + Each tab will become 1 .. 8 spaces, depending on its position in + its line. If false, each tab is treated as a single character. + replace_whitespace (default: true) + Replace all whitespace characters in the input text by spaces + after tab expansion. Note that if expand_tabs is false and + replace_whitespace is true, every tab will be converted to a + single space! + fix_sentence_endings (default: false) + Ensure that sentence-ending punctuation is always followed + by two spaces. Off by default because the algorithm is + (unavoidably) imperfect. + break_long_words (default: true) + Break words longer than 'width'. If false, those words will not + be broken, and some lines might be longer than 'width'. + break_on_hyphens (default: true) + Allow breaking hyphenated words. If true, wrapping will occur + preferably on whitespaces and right after hyphens part of + compound words. + drop_whitespace (default: true) + Drop leading and trailing whitespace from lines. + """ + + whitespace_trans = string.maketrans(_whitespace, ' ' * len(_whitespace)) + + unicode_whitespace_trans = {} + uspace = ord(u' ') + for x in map(ord, _whitespace): + unicode_whitespace_trans[x] = uspace + + # This funky little regex is just the trick for splitting + # text up into word-wrappable chunks. E.g. + # "Hello there -- you goof-ball, use the -b option!" + # splits into + # Hello/ /there/ /--/ /you/ /goof-/ball,/ /use/ /the/ /-b/ /option! + # (after stripping out empty strings). + wordsep_re = re.compile( + r'(\s+|' # any whitespace + r'[^\s\w]*\w+[^0-9\W]-(?=\w+[^0-9\W])|' # hyphenated words + r'(?<=[\w\!\"\'\&\.\,\?])-{2,}(?=\w))') # em-dash + + # This less funky little regex just split on recognized spaces. E.g. + # "Hello there -- you goof-ball, use the -b option!" + # splits into + # Hello/ /there/ /--/ /you/ /goof-ball,/ /use/ /the/ /-b/ /option!/ + wordsep_simple_re = re.compile(r'(\s+)') + + # XXX this is not locale- or charset-aware -- string.lowercase + # is US-ASCII only (and therefore English-only) + sentence_end_re = re.compile(r'[%s]' # lowercase letter + r'[\.\!\?]' # sentence-ending punct. + r'[\"\']?' # optional end-of-quote + r'\Z' # end of chunk + % string.lowercase) + + + def __init__(self, + width=70, + initial_indent="", + subsequent_indent="", + expand_tabs=True, + replace_whitespace=True, + fix_sentence_endings=False, + break_long_words=True, + drop_whitespace=True, + break_on_hyphens=True): + self.width = width + self.initial_indent = initial_indent + self.subsequent_indent = subsequent_indent + self.expand_tabs = expand_tabs + self.replace_whitespace = replace_whitespace + self.fix_sentence_endings = fix_sentence_endings + self.break_long_words = break_long_words + self.drop_whitespace = drop_whitespace + self.break_on_hyphens = break_on_hyphens + + # recompile the regexes for Unicode mode -- done in this clumsy way for + # backwards compatibility because it's rather common to monkey-patch + # the TextWrapper class' wordsep_re attribute. + self.wordsep_re_uni = re.compile(self.wordsep_re.pattern, re.U) + self.wordsep_simple_re_uni = re.compile( + self.wordsep_simple_re.pattern, re.U) + + + # -- Private methods ----------------------------------------------- + # (possibly useful for subclasses to override) + + def _munge_whitespace(self, text): + """_munge_whitespace(text : string) -> string + + Munge whitespace in text: expand tabs and convert all other + whitespace characters to spaces. Eg. " foo\\tbar\\n\\nbaz" + becomes " foo bar baz". + """ + if self.expand_tabs: + text = text.expandtabs() + if self.replace_whitespace: + if isinstance(text, str): + text = text.translate(self.whitespace_trans) + elif isinstance(text, _unicode): + text = text.translate(self.unicode_whitespace_trans) + return text + + + def _split(self, text): + """_split(text : string) -> [string] + + Split the text to wrap into indivisible chunks. Chunks are + not quite the same as words; see _wrap_chunks() for full + details. As an example, the text + Look, goof-ball -- use the -b option! + breaks into the following chunks: + 'Look,', ' ', 'goof-', 'ball', ' ', '--', ' ', + 'use', ' ', 'the', ' ', '-b', ' ', 'option!' + if break_on_hyphens is True, or in: + 'Look,', ' ', 'goof-ball', ' ', '--', ' ', + 'use', ' ', 'the', ' ', '-b', ' ', option!' + otherwise. + """ + if isinstance(text, _unicode): + if self.break_on_hyphens: + pat = self.wordsep_re_uni + else: + pat = self.wordsep_simple_re_uni + else: + if self.break_on_hyphens: + pat = self.wordsep_re + else: + pat = self.wordsep_simple_re + chunks = pat.split(text) + chunks = filter(None, chunks) # remove empty chunks + return chunks + + def _fix_sentence_endings(self, chunks): + """_fix_sentence_endings(chunks : [string]) + + Correct for sentence endings buried in 'chunks'. Eg. when the + original text contains "... foo.\\nBar ...", munge_whitespace() + and split() will convert that to [..., "foo.", " ", "Bar", ...] + which has one too few spaces; this method simply changes the one + space to two. + """ + i = 0 + patsearch = self.sentence_end_re.search + while i < len(chunks)-1: + if chunks[i+1] == " " and patsearch(chunks[i]): + chunks[i+1] = " " + i += 2 + else: + i += 1 + + def _handle_long_word(self, reversed_chunks, cur_line, cur_len, width): + """_handle_long_word(chunks : [string], + cur_line : [string], + cur_len : int, width : int) + + Handle a chunk of text (most likely a word, not whitespace) that + is too long to fit in any line. + """ + # Figure out when indent is larger than the specified width, and make + # sure at least one character is stripped off on every pass + if width < 1: + space_left = 1 + else: + space_left = width - cur_len + + # If we're allowed to break long words, then do so: put as much + # of the next chunk onto the current line as will fit. + if self.break_long_words: + cur_line.append(reversed_chunks[-1][:space_left]) + reversed_chunks[-1] = reversed_chunks[-1][space_left:] + + # Otherwise, we have to preserve the long word intact. Only add + # it to the current line if there's nothing already there -- + # that minimizes how much we violate the width constraint. + elif not cur_line: + cur_line.append(reversed_chunks.pop()) + + # If we're not allowed to break long words, and there's already + # text on the current line, do nothing. Next time through the + # main loop of _wrap_chunks(), we'll wind up here again, but + # cur_len will be zero, so the next line will be entirely + # devoted to the long word that we can't handle right now. + + def _wrap_chunks(self, chunks): + """_wrap_chunks(chunks : [string]) -> [string] + + Wrap a sequence of text chunks and return a list of lines of + length 'self.width' or less. (If 'break_long_words' is false, + some lines may be longer than this.) Chunks correspond roughly + to words and the whitespace between them: each chunk is + indivisible (modulo 'break_long_words'), but a line break can + come between any two chunks. Chunks should not have internal + whitespace; ie. a chunk is either all whitespace or a "word". + Whitespace chunks will be removed from the beginning and end of + lines, but apart from that whitespace is preserved. + """ + lines = [] + if self.width <= 0: + raise ValueError("invalid width %r (must be > 0)" % self.width) + + # Arrange in reverse order so items can be efficiently popped + # from a stack of chucks. + chunks.reverse() + + while chunks: + + # Start the list of chunks that will make up the current line. + # cur_len is just the length of all the chunks in cur_line. + cur_line = [] + cur_len = 0 + + # Figure out which static string will prefix this line. + if lines: + indent = self.subsequent_indent + else: + indent = self.initial_indent + + # Maximum width for this line. + width = self.width - len(indent) + + # First chunk on line is whitespace -- drop it, unless this + # is the very beginning of the text (ie. no lines started yet). + if self.drop_whitespace and chunks[-1].strip() == '' and lines: + del chunks[-1] + + while chunks: + l = len(chunks[-1]) + + # Can at least squeeze this chunk onto the current line. + if cur_len + l <= width: + cur_line.append(chunks.pop()) + cur_len += l + + # Nope, this line is full. + else: + break + + # The current line is full, and the next chunk is too big to + # fit on *any* line (not just this one). + if chunks and len(chunks[-1]) > width: + self._handle_long_word(chunks, cur_line, cur_len, width) + + # If the last chunk on this line is all whitespace, drop it. + if self.drop_whitespace and cur_line and cur_line[-1].strip() == '': + del cur_line[-1] + + # Convert current line back to a string and store it in list + # of all lines (return value). + if cur_line: + lines.append(indent + ''.join(cur_line)) + + return lines + + + # -- Public interface ---------------------------------------------- + + def wrap(self, text): + """wrap(text : string) -> [string] + + Reformat the single paragraph in 'text' so it fits in lines of + no more than 'self.width' columns, and return a list of wrapped + lines. Tabs in 'text' are expanded with string.expandtabs(), + and all other whitespace characters (including newline) are + converted to space. + """ + text = self._munge_whitespace(text) + chunks = self._split(text) + if self.fix_sentence_endings: + self._fix_sentence_endings(chunks) + return self._wrap_chunks(chunks) + + def fill(self, text): + """fill(text : string) -> string + + Reformat the single paragraph in 'text' to fit in lines of no + more than 'self.width' columns, and return a new string + containing the entire wrapped paragraph. + """ + return "\n".join(self.wrap(text)) + + +# -- Convenience interface --------------------------------------------- + +def wrap(text, width=70, **kwargs): + """Wrap a single paragraph of text, returning a list of wrapped lines. + + Reformat the single paragraph in 'text' so it fits in lines of no + more than 'width' columns, and return a list of wrapped lines. By + default, tabs in 'text' are expanded with string.expandtabs(), and + all other whitespace characters (including newline) are converted to + space. See TextWrapper class for available keyword args to customize + wrapping behaviour. + """ + w = TextWrapper(width=width, **kwargs) + return w.wrap(text) + +def fill(text, width=70, **kwargs): + """Fill a single paragraph of text, returning a new string. + + Reformat the single paragraph in 'text' to fit in lines of no more + than 'width' columns, and return a new string containing the entire + wrapped paragraph. As with wrap(), tabs are expanded and other + whitespace characters converted to space. See TextWrapper class for + available keyword args to customize wrapping behaviour. + """ + w = TextWrapper(width=width, **kwargs) + return w.fill(text) + + +# -- Loosely related functionality ------------------------------------- + +_whitespace_only_re = re.compile('^[ \t]+$', re.MULTILINE) +_leading_whitespace_re = re.compile('(^[ \t]*)(?:[^ \t\n])', re.MULTILINE) + +def dedent(text): + """Remove any common leading whitespace from every line in `text`. + + This can be used to make triple-quoted strings line up with the left + edge of the display, while still presenting them in the source code + in indented form. + + Note that tabs and spaces are both treated as whitespace, but they + are not equal: the lines " hello" and "\\thello" are + considered to have no common leading whitespace. (This behaviour is + new in Python 2.5; older versions of this module incorrectly + expanded tabs before searching for common leading whitespace.) + """ + # Look for the longest leading string of spaces and tabs common to + # all lines. + margin = None + text = _whitespace_only_re.sub('', text) + indents = _leading_whitespace_re.findall(text) + for indent in indents: + if margin is None: + margin = indent + + # Current line more deeply indented than previous winner: + # no change (previous winner is still on top). + elif indent.startswith(margin): + pass + + # Current line consistent with and no deeper than previous winner: + # it's the new winner. + elif margin.startswith(indent): + margin = indent + + # Find the largest common whitespace between current line and previous + # winner. + else: + for i, (x, y) in enumerate(zip(margin, indent)): + if x != y: + margin = margin[:i] + break + else: + margin = margin[:len(indent)] + + # sanity check (testing/debugging only) + if 0 and margin: + for line in text.split("\n"): + assert not line or line.startswith(margin), \ + "line = %r, margin = %r" % (line, margin) + + if margin: + text = re.sub(r'(?m)^' + margin, '', text) + return text + +if __name__ == "__main__": + #print dedent("\tfoo\n\tbar") + #print dedent(" \thello there\n \t how are you?") + print dedent("Hello there.\n This is indented.") diff --git a/types.py b/types.py new file mode 100644 index 0000000..d414f54 --- /dev/null +++ b/types.py @@ -0,0 +1,86 @@ +"""Define names for all type symbols known in the standard interpreter. + +Types that are part of optional modules (e.g. array) are not listed. +""" +import sys + +# Iterators in Python aren't a matter of type but of protocol. A large +# and changing number of builtin types implement *some* flavor of +# iterator. Don't check the type! Use hasattr to check for both +# "__iter__" and "next" attributes instead. + +NoneType = type(None) +TypeType = type +ObjectType = object + +IntType = int +LongType = long +FloatType = float +BooleanType = bool +try: + ComplexType = complex +except NameError: + pass + +StringType = str + +# StringTypes is already outdated. Instead of writing "type(x) in +# types.StringTypes", you should use "isinstance(x, basestring)". But +# we keep around for compatibility with Python 2.2. +try: + UnicodeType = unicode + StringTypes = (StringType, UnicodeType) +except NameError: + StringTypes = (StringType,) + +BufferType = buffer + +TupleType = tuple +ListType = list +DictType = DictionaryType = dict + +def _f(): pass +FunctionType = type(_f) +LambdaType = type(lambda: None) # Same as FunctionType +CodeType = type(_f.func_code) + +def _g(): + yield 1 +GeneratorType = type(_g()) + +class _C: + def _m(self): pass +ClassType = type(_C) +UnboundMethodType = type(_C._m) # Same as MethodType +_x = _C() +InstanceType = type(_x) +MethodType = type(_x._m) + +BuiltinFunctionType = type(len) +BuiltinMethodType = type([].append) # Same as BuiltinFunctionType + +ModuleType = type(sys) +FileType = file +XRangeType = xrange + +try: + raise TypeError +except TypeError: + tb = sys.exc_info()[2] + TracebackType = type(tb) + FrameType = type(tb.tb_frame) + del tb + +SliceType = slice +EllipsisType = type(Ellipsis) + +DictProxyType = type(TypeType.__dict__) +NotImplementedType = type(NotImplemented) + +# For Jython, the following two types are identical +GetSetDescriptorType = type(FunctionType.func_code) +MemberDescriptorType = type(FunctionType.func_globals) + +del sys, _f, _g, _C, _x # Not for export + +__all__ = list(n for n in globals() if n[:1] != '_') diff --git a/urllib.py b/urllib.py new file mode 100644 index 0000000..ccb0574 --- /dev/null +++ b/urllib.py @@ -0,0 +1,1637 @@ +"""Open an arbitrary URL. + +See the following document for more info on URLs: +"Names and Addresses, URIs, URLs, URNs, URCs", at +http://www.w3.org/pub/WWW/Addressing/Overview.html + +See also the HTTP spec (from which the error codes are derived): +"HTTP - Hypertext Transfer Protocol", at +http://www.w3.org/pub/WWW/Protocols/ + +Related standards and specs: +- RFC1808: the "relative URL" spec. (authoritative status) +- RFC1738 - the "URL standard". (authoritative status) +- RFC1630 - the "URI spec". (informational status) + +The object returned by URLopener().open(file) will differ per +protocol. All you know is that is has methods read(), readline(), +readlines(), fileno(), close() and info(). The read*(), fileno() +and close() methods work like those of open files. +The info() method returns a mimetools.Message object which can be +used to query various info about the object, if available. +(mimetools.Message objects are queried with the getheader() method.) +""" + +import string +import socket +import os +import time +import sys +import base64 +import re + +from urlparse import urljoin as basejoin + +__all__ = ["urlopen", "URLopener", "FancyURLopener", "urlretrieve", + "urlcleanup", "quote", "quote_plus", "unquote", "unquote_plus", + "urlencode", "url2pathname", "pathname2url", "splittag", + "localhost", "thishost", "ftperrors", "basejoin", "unwrap", + "splittype", "splithost", "splituser", "splitpasswd", "splitport", + "splitnport", "splitquery", "splitattr", "splitvalue", + "getproxies"] + +__version__ = '1.17' # XXX This version is not always updated :-( + +MAXFTPCACHE = 10 # Trim the ftp cache beyond this size + +# Helper for non-unix systems +if os.name == 'nt': + from nturl2path import url2pathname, pathname2url +elif os.name == 'riscos': + from rourl2path import url2pathname, pathname2url +else: + def url2pathname(pathname): + """OS-specific conversion from a relative URL of the 'file' scheme + to a file system path; not recommended for general use.""" + return unquote(pathname) + + def pathname2url(pathname): + """OS-specific conversion from a file system path to a relative URL + of the 'file' scheme; not recommended for general use.""" + return quote(pathname) + +# This really consists of two pieces: +# (1) a class which handles opening of all sorts of URLs +# (plus assorted utilities etc.) +# (2) a set of functions for parsing URLs +# XXX Should these be separated out into different modules? + + +# Shortcut for basic usage +_urlopener = None +def urlopen(url, data=None, proxies=None, context=None): + """Create a file-like object for the specified URL to read from.""" + from warnings import warnpy3k + warnpy3k("urllib.urlopen() has been removed in Python 3.0 in " + "favor of urllib2.urlopen()", stacklevel=2) + + global _urlopener + if proxies is not None or context is not None: + opener = FancyURLopener(proxies=proxies, context=context) + elif not _urlopener: + opener = FancyURLopener() + _urlopener = opener + else: + opener = _urlopener + if data is None: + return opener.open(url) + else: + return opener.open(url, data) +def urlretrieve(url, filename=None, reporthook=None, data=None, context=None): + global _urlopener + if context is not None: + opener = FancyURLopener(context=context) + elif not _urlopener: + _urlopener = opener = FancyURLopener() + else: + opener = _urlopener + return opener.retrieve(url, filename, reporthook, data) +def urlcleanup(): + if _urlopener: + _urlopener.cleanup() + _safe_quoters.clear() + ftpcache.clear() + +# check for SSL +try: + import ssl +except: + _have_ssl = False +else: + _have_ssl = True + +# exception raised when downloaded size does not match content-length +class ContentTooShortError(IOError): + def __init__(self, message, content): + IOError.__init__(self, message) + self.content = content + +ftpcache = {} +class URLopener: + """Class to open URLs. + This is a class rather than just a subroutine because we may need + more than one set of global protocol-specific options. + Note -- this is a base class for those who don't want the + automatic handling of errors type 302 (relocated) and 401 + (authorization needed).""" + + __tempfiles = None + + version = "Python-urllib/%s" % __version__ + + # Constructor + def __init__(self, proxies=None, context=None, **x509): + if proxies is None: + proxies = getproxies() + assert hasattr(proxies, 'has_key'), "proxies must be a mapping" + self.proxies = proxies + self.key_file = x509.get('key_file') + self.cert_file = x509.get('cert_file') + self.context = context + self.addheaders = [('User-Agent', self.version)] + self.__tempfiles = [] + self.__unlink = os.unlink # See cleanup() + self.tempcache = None + # Undocumented feature: if you assign {} to tempcache, + # it is used to cache files retrieved with + # self.retrieve(). This is not enabled by default + # since it does not work for changing documents (and I + # haven't got the logic to check expiration headers + # yet). + self.ftpcache = ftpcache + # Undocumented feature: you can use a different + # ftp cache by assigning to the .ftpcache member; + # in case you want logically independent URL openers + # XXX This is not threadsafe. Bah. + + def __del__(self): + self.close() + + def close(self): + self.cleanup() + + def cleanup(self): + # This code sometimes runs when the rest of this module + # has already been deleted, so it can't use any globals + # or import anything. + if self.__tempfiles: + for file in self.__tempfiles: + try: + self.__unlink(file) + except OSError: + pass + del self.__tempfiles[:] + if self.tempcache: + self.tempcache.clear() + + def addheader(self, *args): + """Add a header to be used by the HTTP interface only + e.g. u.addheader('Accept', 'sound/basic')""" + self.addheaders.append(args) + + # External interface + def open(self, fullurl, data=None): + """Use URLopener().open(file) instead of open(file, 'r').""" + fullurl = unwrap(toBytes(fullurl)) + # percent encode url, fixing lame server errors for e.g, like space + # within url paths. + fullurl = quote(fullurl, safe="%/:=&?~#+!$,;'@()*[]|") + if self.tempcache and fullurl in self.tempcache: + filename, headers = self.tempcache[fullurl] + fp = open(filename, 'rb') + return addinfourl(fp, headers, fullurl) + urltype, url = splittype(fullurl) + if not urltype: + urltype = 'file' + if urltype in self.proxies: + proxy = self.proxies[urltype] + urltype, proxyhost = splittype(proxy) + host, selector = splithost(proxyhost) + url = (host, fullurl) # Signal special case to open_*() + else: + proxy = None + name = 'open_' + urltype + self.type = urltype + name = name.replace('-', '_') + if not hasattr(self, name): + if proxy: + return self.open_unknown_proxy(proxy, fullurl, data) + else: + return self.open_unknown(fullurl, data) + try: + if data is None: + return getattr(self, name)(url) + else: + return getattr(self, name)(url, data) + except socket.error, msg: + raise IOError, ('socket error', msg), sys.exc_info()[2] + + def open_unknown(self, fullurl, data=None): + """Overridable interface to open unknown URL type.""" + type, url = splittype(fullurl) + raise IOError, ('url error', 'unknown url type', type) + + def open_unknown_proxy(self, proxy, fullurl, data=None): + """Overridable interface to open unknown URL type.""" + type, url = splittype(fullurl) + raise IOError, ('url error', 'invalid proxy for %s' % type, proxy) + + # External interface + def retrieve(self, url, filename=None, reporthook=None, data=None): + """retrieve(url) returns (filename, headers) for a local object + or (tempfilename, headers) for a remote object.""" + url = unwrap(toBytes(url)) + if self.tempcache and url in self.tempcache: + return self.tempcache[url] + type, url1 = splittype(url) + if filename is None and (not type or type == 'file'): + try: + fp = self.open_local_file(url1) + hdrs = fp.info() + fp.close() + return url2pathname(splithost(url1)[1]), hdrs + except IOError: + pass + fp = self.open(url, data) + try: + headers = fp.info() + if filename: + tfp = open(filename, 'wb') + else: + import tempfile + garbage, path = splittype(url) + garbage, path = splithost(path or "") + path, garbage = splitquery(path or "") + path, garbage = splitattr(path or "") + suffix = os.path.splitext(path)[1] + (fd, filename) = tempfile.mkstemp(suffix) + self.__tempfiles.append(filename) + tfp = os.fdopen(fd, 'wb') + try: + result = filename, headers + if self.tempcache is not None: + self.tempcache[url] = result + bs = 1024*8 + size = -1 + read = 0 + blocknum = 0 + if "content-length" in headers: + size = int(headers["Content-Length"]) + if reporthook: + reporthook(blocknum, bs, size) + while 1: + block = fp.read(bs) + if block == "": + break + read += len(block) + tfp.write(block) + blocknum += 1 + if reporthook: + reporthook(blocknum, bs, size) + finally: + tfp.close() + finally: + fp.close() + + # raise exception if actual size does not match content-length header + if size >= 0 and read < size: + raise ContentTooShortError("retrieval incomplete: got only %i out " + "of %i bytes" % (read, size), result) + + return result + + # Each method named open_ knows how to open that type of URL + + def open_http(self, url, data=None): + """Use HTTP protocol.""" + import httplib + user_passwd = None + proxy_passwd= None + if isinstance(url, str): + host, selector = splithost(url) + if host: + user_passwd, host = splituser(host) + host = unquote(host) + realhost = host + else: + host, selector = url + # check whether the proxy contains authorization information + proxy_passwd, host = splituser(host) + # now we proceed with the url we want to obtain + urltype, rest = splittype(selector) + url = rest + user_passwd = None + if urltype.lower() != 'http': + realhost = None + else: + realhost, rest = splithost(rest) + if realhost: + user_passwd, realhost = splituser(realhost) + if user_passwd: + selector = "%s://%s%s" % (urltype, realhost, rest) + if proxy_bypass(realhost): + host = realhost + + #print "proxy via http:", host, selector + if not host: raise IOError, ('http error', 'no host given') + + if proxy_passwd: + proxy_passwd = unquote(proxy_passwd) + proxy_auth = base64.b64encode(proxy_passwd).strip() + else: + proxy_auth = None + + if user_passwd: + user_passwd = unquote(user_passwd) + auth = base64.b64encode(user_passwd).strip() + else: + auth = None + h = httplib.HTTP(host) + if data is not None: + h.putrequest('POST', selector) + h.putheader('Content-Type', 'application/x-www-form-urlencoded') + h.putheader('Content-Length', '%d' % len(data)) + else: + h.putrequest('GET', selector) + if proxy_auth: h.putheader('Proxy-Authorization', 'Basic %s' % proxy_auth) + if auth: h.putheader('Authorization', 'Basic %s' % auth) + if realhost: h.putheader('Host', realhost) + for args in self.addheaders: h.putheader(*args) + h.endheaders(data) + errcode, errmsg, headers = h.getreply() + fp = h.getfile() + if errcode == -1: + if fp: fp.close() + # something went wrong with the HTTP status line + raise IOError, ('http protocol error', 0, + 'got a bad status line', None) + # According to RFC 2616, "2xx" code indicates that the client's + # request was successfully received, understood, and accepted. + if (200 <= errcode < 300): + return addinfourl(fp, headers, "http:" + url, errcode) + else: + if data is None: + return self.http_error(url, fp, errcode, errmsg, headers) + else: + return self.http_error(url, fp, errcode, errmsg, headers, data) + + def http_error(self, url, fp, errcode, errmsg, headers, data=None): + """Handle http errors. + Derived class can override this, or provide specific handlers + named http_error_DDD where DDD is the 3-digit error code.""" + # First check if there's a specific handler for this error + name = 'http_error_%d' % errcode + if hasattr(self, name): + method = getattr(self, name) + if data is None: + result = method(url, fp, errcode, errmsg, headers) + else: + result = method(url, fp, errcode, errmsg, headers, data) + if result: return result + return self.http_error_default(url, fp, errcode, errmsg, headers) + + def http_error_default(self, url, fp, errcode, errmsg, headers): + """Default error handler: close the connection and raise IOError.""" + fp.close() + raise IOError, ('http error', errcode, errmsg, headers) + + if _have_ssl: + def open_https(self, url, data=None): + """Use HTTPS protocol.""" + + import httplib + user_passwd = None + proxy_passwd = None + if isinstance(url, str): + host, selector = splithost(url) + if host: + user_passwd, host = splituser(host) + host = unquote(host) + realhost = host + else: + host, selector = url + # here, we determine, whether the proxy contains authorization information + proxy_passwd, host = splituser(host) + urltype, rest = splittype(selector) + url = rest + user_passwd = None + if urltype.lower() != 'https': + realhost = None + else: + realhost, rest = splithost(rest) + if realhost: + user_passwd, realhost = splituser(realhost) + if user_passwd: + selector = "%s://%s%s" % (urltype, realhost, rest) + #print "proxy via https:", host, selector + if not host: raise IOError, ('https error', 'no host given') + if proxy_passwd: + proxy_passwd = unquote(proxy_passwd) + proxy_auth = base64.b64encode(proxy_passwd).strip() + else: + proxy_auth = None + if user_passwd: + user_passwd = unquote(user_passwd) + auth = base64.b64encode(user_passwd).strip() + else: + auth = None + h = httplib.HTTPS(host, 0, + key_file=self.key_file, + cert_file=self.cert_file, + context=self.context) + if data is not None: + h.putrequest('POST', selector) + h.putheader('Content-Type', + 'application/x-www-form-urlencoded') + h.putheader('Content-Length', '%d' % len(data)) + else: + h.putrequest('GET', selector) + if proxy_auth: h.putheader('Proxy-Authorization', 'Basic %s' % proxy_auth) + if auth: h.putheader('Authorization', 'Basic %s' % auth) + if realhost: h.putheader('Host', realhost) + for args in self.addheaders: h.putheader(*args) + h.endheaders(data) + errcode, errmsg, headers = h.getreply() + fp = h.getfile() + if errcode == -1: + if fp: fp.close() + # something went wrong with the HTTP status line + raise IOError, ('http protocol error', 0, + 'got a bad status line', None) + # According to RFC 2616, "2xx" code indicates that the client's + # request was successfully received, understood, and accepted. + if (200 <= errcode < 300): + return addinfourl(fp, headers, "https:" + url, errcode) + else: + if data is None: + return self.http_error(url, fp, errcode, errmsg, headers) + else: + return self.http_error(url, fp, errcode, errmsg, headers, + data) + + def open_file(self, url): + """Use local file or FTP depending on form of URL.""" + if not isinstance(url, str): + raise IOError, ('file error', 'proxy support for file protocol currently not implemented') + if url[:2] == '//' and url[2:3] != '/' and url[2:12].lower() != 'localhost/': + return self.open_ftp(url) + else: + return self.open_local_file(url) + + def open_local_file(self, url): + """Use local file.""" + import mimetypes, mimetools, email.utils + try: + from cStringIO import StringIO + except ImportError: + from StringIO import StringIO + host, file = splithost(url) + localname = url2pathname(file) + try: + stats = os.stat(localname) + except OSError, e: + raise IOError(e.errno, e.strerror, e.filename) + size = stats.st_size + modified = email.utils.formatdate(stats.st_mtime, usegmt=True) + mtype = mimetypes.guess_type(url)[0] + headers = mimetools.Message(StringIO( + 'Content-Type: %s\nContent-Length: %d\nLast-modified: %s\n' % + (mtype or 'text/plain', size, modified))) + if not host: + urlfile = file + if file[:1] == '/': + urlfile = 'file://' + file + elif file[:2] == './': + raise ValueError("local file url may start with / or file:. Unknown url of type: %s" % url) + return addinfourl(open(localname, 'rb'), + headers, urlfile) + host, port = splitport(host) + if not port \ + and socket.gethostbyname(host) in (localhost(), thishost()): + urlfile = file + if file[:1] == '/': + urlfile = 'file://' + file + return addinfourl(open(localname, 'rb'), + headers, urlfile) + raise IOError, ('local file error', 'not on local host') + + def open_ftp(self, url): + """Use FTP protocol.""" + if not isinstance(url, str): + raise IOError, ('ftp error', 'proxy support for ftp protocol currently not implemented') + import mimetypes, mimetools + try: + from cStringIO import StringIO + except ImportError: + from StringIO import StringIO + host, path = splithost(url) + if not host: raise IOError, ('ftp error', 'no host given') + host, port = splitport(host) + user, host = splituser(host) + if user: user, passwd = splitpasswd(user) + else: passwd = None + host = unquote(host) + user = user or '' + passwd = passwd or '' + host = socket.gethostbyname(host) + if not port: + import ftplib + port = ftplib.FTP_PORT + else: + port = int(port) + path, attrs = splitattr(path) + path = unquote(path) + dirs = path.split('/') + dirs, file = dirs[:-1], dirs[-1] + if dirs and not dirs[0]: dirs = dirs[1:] + if dirs and not dirs[0]: dirs[0] = '/' + key = user, host, port, '/'.join(dirs) + # XXX thread unsafe! + if len(self.ftpcache) > MAXFTPCACHE: + # Prune the cache, rather arbitrarily + for k in self.ftpcache.keys(): + if k != key: + v = self.ftpcache[k] + del self.ftpcache[k] + v.close() + try: + if not key in self.ftpcache: + self.ftpcache[key] = \ + ftpwrapper(user, passwd, host, port, dirs) + if not file: type = 'D' + else: type = 'I' + for attr in attrs: + attr, value = splitvalue(attr) + if attr.lower() == 'type' and \ + value in ('a', 'A', 'i', 'I', 'd', 'D'): + type = value.upper() + (fp, retrlen) = self.ftpcache[key].retrfile(file, type) + mtype = mimetypes.guess_type("ftp:" + url)[0] + headers = "" + if mtype: + headers += "Content-Type: %s\n" % mtype + if retrlen is not None and retrlen >= 0: + headers += "Content-Length: %d\n" % retrlen + headers = mimetools.Message(StringIO(headers)) + return addinfourl(fp, headers, "ftp:" + url) + except ftperrors(), msg: + raise IOError, ('ftp error', msg), sys.exc_info()[2] + + def open_data(self, url, data=None): + """Use "data" URL.""" + if not isinstance(url, str): + raise IOError, ('data error', 'proxy support for data protocol currently not implemented') + # ignore POSTed data + # + # syntax of data URLs: + # dataurl := "data:" [ mediatype ] [ ";base64" ] "," data + # mediatype := [ type "/" subtype ] *( ";" parameter ) + # data := *urlchar + # parameter := attribute "=" value + import mimetools + try: + from cStringIO import StringIO + except ImportError: + from StringIO import StringIO + try: + [type, data] = url.split(',', 1) + except ValueError: + raise IOError, ('data error', 'bad data URL') + if not type: + type = 'text/plain;charset=US-ASCII' + semi = type.rfind(';') + if semi >= 0 and '=' not in type[semi:]: + encoding = type[semi+1:] + type = type[:semi] + else: + encoding = '' + msg = [] + msg.append('Date: %s'%time.strftime('%a, %d %b %Y %H:%M:%S GMT', + time.gmtime(time.time()))) + msg.append('Content-type: %s' % type) + if encoding == 'base64': + data = base64.decodestring(data) + else: + data = unquote(data) + msg.append('Content-Length: %d' % len(data)) + msg.append('') + msg.append(data) + msg = '\n'.join(msg) + f = StringIO(msg) + headers = mimetools.Message(f, 0) + #f.fileno = None # needed for addinfourl + return addinfourl(f, headers, url) + + +class FancyURLopener(URLopener): + """Derived class with handlers for errors we can handle (perhaps).""" + + def __init__(self, *args, **kwargs): + URLopener.__init__(self, *args, **kwargs) + self.auth_cache = {} + self.tries = 0 + self.maxtries = 10 + + def http_error_default(self, url, fp, errcode, errmsg, headers): + """Default error handling -- don't raise an exception.""" + return addinfourl(fp, headers, "http:" + url, errcode) + + def http_error_302(self, url, fp, errcode, errmsg, headers, data=None): + """Error 302 -- relocated (temporarily).""" + self.tries += 1 + if self.maxtries and self.tries >= self.maxtries: + if hasattr(self, "http_error_500"): + meth = self.http_error_500 + else: + meth = self.http_error_default + self.tries = 0 + return meth(url, fp, 500, + "Internal Server Error: Redirect Recursion", headers) + result = self.redirect_internal(url, fp, errcode, errmsg, headers, + data) + self.tries = 0 + return result + + def redirect_internal(self, url, fp, errcode, errmsg, headers, data): + if 'location' in headers: + newurl = headers['location'] + elif 'uri' in headers: + newurl = headers['uri'] + else: + return + fp.close() + # In case the server sent a relative URL, join with original: + newurl = basejoin(self.type + ":" + url, newurl) + + # For security reasons we do not allow redirects to protocols + # other than HTTP, HTTPS or FTP. + newurl_lower = newurl.lower() + if not (newurl_lower.startswith('http://') or + newurl_lower.startswith('https://') or + newurl_lower.startswith('ftp://')): + raise IOError('redirect error', errcode, + errmsg + " - Redirection to url '%s' is not allowed" % + newurl, + headers) + + return self.open(newurl) + + def http_error_301(self, url, fp, errcode, errmsg, headers, data=None): + """Error 301 -- also relocated (permanently).""" + return self.http_error_302(url, fp, errcode, errmsg, headers, data) + + def http_error_303(self, url, fp, errcode, errmsg, headers, data=None): + """Error 303 -- also relocated (essentially identical to 302).""" + return self.http_error_302(url, fp, errcode, errmsg, headers, data) + + def http_error_307(self, url, fp, errcode, errmsg, headers, data=None): + """Error 307 -- relocated, but turn POST into error.""" + if data is None: + return self.http_error_302(url, fp, errcode, errmsg, headers, data) + else: + return self.http_error_default(url, fp, errcode, errmsg, headers) + + def http_error_401(self, url, fp, errcode, errmsg, headers, data=None): + """Error 401 -- authentication required. + This function supports Basic authentication only.""" + if not 'www-authenticate' in headers: + URLopener.http_error_default(self, url, fp, + errcode, errmsg, headers) + stuff = headers['www-authenticate'] + import re + match = re.match('[ \t]*([^ \t]+)[ \t]+realm="([^"]*)"', stuff) + if not match: + URLopener.http_error_default(self, url, fp, + errcode, errmsg, headers) + scheme, realm = match.groups() + if scheme.lower() != 'basic': + URLopener.http_error_default(self, url, fp, + errcode, errmsg, headers) + name = 'retry_' + self.type + '_basic_auth' + if data is None: + return getattr(self,name)(url, realm) + else: + return getattr(self,name)(url, realm, data) + + def http_error_407(self, url, fp, errcode, errmsg, headers, data=None): + """Error 407 -- proxy authentication required. + This function supports Basic authentication only.""" + if not 'proxy-authenticate' in headers: + URLopener.http_error_default(self, url, fp, + errcode, errmsg, headers) + stuff = headers['proxy-authenticate'] + import re + match = re.match('[ \t]*([^ \t]+)[ \t]+realm="([^"]*)"', stuff) + if not match: + URLopener.http_error_default(self, url, fp, + errcode, errmsg, headers) + scheme, realm = match.groups() + if scheme.lower() != 'basic': + URLopener.http_error_default(self, url, fp, + errcode, errmsg, headers) + name = 'retry_proxy_' + self.type + '_basic_auth' + if data is None: + return getattr(self,name)(url, realm) + else: + return getattr(self,name)(url, realm, data) + + def retry_proxy_http_basic_auth(self, url, realm, data=None): + host, selector = splithost(url) + newurl = 'http://' + host + selector + proxy = self.proxies['http'] + urltype, proxyhost = splittype(proxy) + proxyhost, proxyselector = splithost(proxyhost) + i = proxyhost.find('@') + 1 + proxyhost = proxyhost[i:] + user, passwd = self.get_user_passwd(proxyhost, realm, i) + if not (user or passwd): return None + proxyhost = quote(user, safe='') + ':' + quote(passwd, safe='') + '@' + proxyhost + self.proxies['http'] = 'http://' + proxyhost + proxyselector + if data is None: + return self.open(newurl) + else: + return self.open(newurl, data) + + def retry_proxy_https_basic_auth(self, url, realm, data=None): + host, selector = splithost(url) + newurl = 'https://' + host + selector + proxy = self.proxies['https'] + urltype, proxyhost = splittype(proxy) + proxyhost, proxyselector = splithost(proxyhost) + i = proxyhost.find('@') + 1 + proxyhost = proxyhost[i:] + user, passwd = self.get_user_passwd(proxyhost, realm, i) + if not (user or passwd): return None + proxyhost = quote(user, safe='') + ':' + quote(passwd, safe='') + '@' + proxyhost + self.proxies['https'] = 'https://' + proxyhost + proxyselector + if data is None: + return self.open(newurl) + else: + return self.open(newurl, data) + + def retry_http_basic_auth(self, url, realm, data=None): + host, selector = splithost(url) + i = host.find('@') + 1 + host = host[i:] + user, passwd = self.get_user_passwd(host, realm, i) + if not (user or passwd): return None + host = quote(user, safe='') + ':' + quote(passwd, safe='') + '@' + host + newurl = 'http://' + host + selector + if data is None: + return self.open(newurl) + else: + return self.open(newurl, data) + + def retry_https_basic_auth(self, url, realm, data=None): + host, selector = splithost(url) + i = host.find('@') + 1 + host = host[i:] + user, passwd = self.get_user_passwd(host, realm, i) + if not (user or passwd): return None + host = quote(user, safe='') + ':' + quote(passwd, safe='') + '@' + host + newurl = 'https://' + host + selector + if data is None: + return self.open(newurl) + else: + return self.open(newurl, data) + + def get_user_passwd(self, host, realm, clear_cache=0): + key = realm + '@' + host.lower() + if key in self.auth_cache: + if clear_cache: + del self.auth_cache[key] + else: + return self.auth_cache[key] + user, passwd = self.prompt_user_passwd(host, realm) + if user or passwd: self.auth_cache[key] = (user, passwd) + return user, passwd + + def prompt_user_passwd(self, host, realm): + """Override this in a GUI environment!""" + import getpass + try: + user = raw_input("Enter username for %s at %s: " % (realm, + host)) + passwd = getpass.getpass("Enter password for %s in %s at %s: " % + (user, realm, host)) + return user, passwd + except KeyboardInterrupt: + print + return None, None + + +# Utility functions + +_localhost = None +def localhost(): + """Return the IP address of the magic hostname 'localhost'.""" + global _localhost + if _localhost is None: + _localhost = socket.gethostbyname('localhost') + return _localhost + +_thishost = None +def thishost(): + """Return the IP address of the current host.""" + global _thishost + if _thishost is None: + try: + _thishost = socket.gethostbyname(socket.gethostname()) + except socket.gaierror: + _thishost = socket.gethostbyname('localhost') + return _thishost + +_ftperrors = None +def ftperrors(): + """Return the set of errors raised by the FTP class.""" + global _ftperrors + if _ftperrors is None: + import ftplib + _ftperrors = ftplib.all_errors + return _ftperrors + +_noheaders = None +def noheaders(): + """Return an empty mimetools.Message object.""" + global _noheaders + if _noheaders is None: + import mimetools + try: + from cStringIO import StringIO + except ImportError: + from StringIO import StringIO + _noheaders = mimetools.Message(StringIO(), 0) + _noheaders.fp.close() # Recycle file descriptor + return _noheaders + + +# Utility classes + +class ftpwrapper: + """Class used by open_ftp() for cache of open FTP connections.""" + + def __init__(self, user, passwd, host, port, dirs, + timeout=socket._GLOBAL_DEFAULT_TIMEOUT, + persistent=True): + self.user = user + self.passwd = passwd + self.host = host + self.port = port + self.dirs = dirs + self.timeout = timeout + self.refcount = 0 + self.keepalive = persistent + try: + self.init() + except: + self.close() + raise + + def init(self): + import ftplib + self.busy = 0 + self.ftp = ftplib.FTP() + self.ftp.connect(self.host, self.port, self.timeout) + self.ftp.login(self.user, self.passwd) + _target = '/'.join(self.dirs) + self.ftp.cwd(_target) + + def retrfile(self, file, type): + import ftplib + self.endtransfer() + if type in ('d', 'D'): cmd = 'TYPE A'; isdir = 1 + else: cmd = 'TYPE ' + type; isdir = 0 + try: + self.ftp.voidcmd(cmd) + except ftplib.all_errors: + self.init() + self.ftp.voidcmd(cmd) + conn = None + if file and not isdir: + # Try to retrieve as a file + try: + cmd = 'RETR ' + file + conn, retrlen = self.ftp.ntransfercmd(cmd) + except ftplib.error_perm, reason: + if str(reason)[:3] != '550': + raise IOError, ('ftp error', reason), sys.exc_info()[2] + if not conn: + # Set transfer mode to ASCII! + self.ftp.voidcmd('TYPE A') + # Try a directory listing. Verify that directory exists. + if file: + pwd = self.ftp.pwd() + try: + try: + self.ftp.cwd(file) + except ftplib.error_perm, reason: + raise IOError, ('ftp error', reason), sys.exc_info()[2] + finally: + self.ftp.cwd(pwd) + cmd = 'LIST ' + file + else: + cmd = 'LIST' + conn, retrlen = self.ftp.ntransfercmd(cmd) + self.busy = 1 + ftpobj = addclosehook(conn.makefile('rb'), self.file_close) + self.refcount += 1 + conn.close() + # Pass back both a suitably decorated object and a retrieval length + return (ftpobj, retrlen) + + def endtransfer(self): + if not self.busy: + return + self.busy = 0 + try: + self.ftp.voidresp() + except ftperrors(): + pass + + def close(self): + self.keepalive = False + if self.refcount <= 0: + self.real_close() + + def file_close(self): + self.endtransfer() + self.refcount -= 1 + if self.refcount <= 0 and not self.keepalive: + self.real_close() + + def real_close(self): + self.endtransfer() + try: + self.ftp.close() + except ftperrors(): + pass + +class addbase: + """Base class for addinfo and addclosehook.""" + + def __init__(self, fp): + self.fp = fp + self.read = self.fp.read + self.readline = self.fp.readline + if hasattr(self.fp, "readlines"): self.readlines = self.fp.readlines + if hasattr(self.fp, "fileno"): + self.fileno = self.fp.fileno + else: + self.fileno = lambda: None + if hasattr(self.fp, "__iter__"): + self.__iter__ = self.fp.__iter__ + if hasattr(self.fp, "next"): + self.next = self.fp.next + + def __repr__(self): + return '<%s at %r whose fp = %r>' % (self.__class__.__name__, + id(self), self.fp) + + def close(self): + self.read = None + self.readline = None + self.readlines = None + self.fileno = None + if self.fp: self.fp.close() + self.fp = None + +class addclosehook(addbase): + """Class to add a close hook to an open file.""" + + def __init__(self, fp, closehook, *hookargs): + addbase.__init__(self, fp) + self.closehook = closehook + self.hookargs = hookargs + + def close(self): + try: + closehook = self.closehook + hookargs = self.hookargs + if closehook: + self.closehook = None + self.hookargs = None + closehook(*hookargs) + finally: + addbase.close(self) + + +class addinfo(addbase): + """class to add an info() method to an open file.""" + + def __init__(self, fp, headers): + addbase.__init__(self, fp) + self.headers = headers + + def info(self): + return self.headers + +class addinfourl(addbase): + """class to add info() and geturl() methods to an open file.""" + + def __init__(self, fp, headers, url, code=None): + addbase.__init__(self, fp) + self.headers = headers + self.url = url + self.code = code + + def info(self): + return self.headers + + def getcode(self): + return self.code + + def geturl(self): + return self.url + + +# Utilities to parse URLs (most of these return None for missing parts): +# unwrap('') --> 'type://host/path' +# splittype('type:opaquestring') --> 'type', 'opaquestring' +# splithost('//host[:port]/path') --> 'host[:port]', '/path' +# splituser('user[:passwd]@host[:port]') --> 'user[:passwd]', 'host[:port]' +# splitpasswd('user:passwd') -> 'user', 'passwd' +# splitport('host:port') --> 'host', 'port' +# splitquery('/path?query') --> '/path', 'query' +# splittag('/path#tag') --> '/path', 'tag' +# splitattr('/path;attr1=value1;attr2=value2;...') -> +# '/path', ['attr1=value1', 'attr2=value2', ...] +# splitvalue('attr=value') --> 'attr', 'value' +# unquote('abc%20def') -> 'abc def' +# quote('abc def') -> 'abc%20def') + +try: + unicode +except NameError: + def _is_unicode(x): + return 0 +else: + def _is_unicode(x): + return isinstance(x, unicode) + +def toBytes(url): + """toBytes(u"URL") --> 'URL'.""" + # Most URL schemes require ASCII. If that changes, the conversion + # can be relaxed + if _is_unicode(url): + try: + url = url.encode("ASCII") + except UnicodeError: + raise UnicodeError("URL " + repr(url) + + " contains non-ASCII characters") + return url + +def unwrap(url): + """unwrap('') --> 'type://host/path'.""" + url = url.strip() + if url[:1] == '<' and url[-1:] == '>': + url = url[1:-1].strip() + if url[:4] == 'URL:': url = url[4:].strip() + return url + +_typeprog = None +def splittype(url): + """splittype('type:opaquestring') --> 'type', 'opaquestring'.""" + global _typeprog + if _typeprog is None: + import re + _typeprog = re.compile('^([^/:]+):') + + match = _typeprog.match(url) + if match: + scheme = match.group(1) + return scheme.lower(), url[len(scheme) + 1:] + return None, url + +_hostprog = None +def splithost(url): + """splithost('//host[:port]/path') --> 'host[:port]', '/path'.""" + global _hostprog + if _hostprog is None: + import re + _hostprog = re.compile('^//([^/?]*)(.*)$') + + match = _hostprog.match(url) + if match: + host_port = match.group(1) + path = match.group(2) + if path and not path.startswith('/'): + path = '/' + path + return host_port, path + return None, url + +_userprog = None +def splituser(host): + """splituser('user[:passwd]@host[:port]') --> 'user[:passwd]', 'host[:port]'.""" + global _userprog + if _userprog is None: + import re + _userprog = re.compile('^(.*)@(.*)$') + + match = _userprog.match(host) + if match: return match.group(1, 2) + return None, host + +_passwdprog = None +def splitpasswd(user): + """splitpasswd('user:passwd') -> 'user', 'passwd'.""" + global _passwdprog + if _passwdprog is None: + import re + _passwdprog = re.compile('^([^:]*):(.*)$',re.S) + + match = _passwdprog.match(user) + if match: return match.group(1, 2) + return user, None + +# splittag('/path#tag') --> '/path', 'tag' +_portprog = None +def splitport(host): + """splitport('host:port') --> 'host', 'port'.""" + global _portprog + if _portprog is None: + import re + _portprog = re.compile('^(.*):([0-9]*)$') + + match = _portprog.match(host) + if match: + host, port = match.groups() + if port: + return host, port + return host, None + +_nportprog = None +def splitnport(host, defport=-1): + """Split host and port, returning numeric port. + Return given default port if no ':' found; defaults to -1. + Return numerical port if a valid number are found after ':'. + Return None if ':' but not a valid number.""" + global _nportprog + if _nportprog is None: + import re + _nportprog = re.compile('^(.*):(.*)$') + + match = _nportprog.match(host) + if match: + host, port = match.group(1, 2) + if port: + try: + nport = int(port) + except ValueError: + nport = None + return host, nport + return host, defport + +_queryprog = None +def splitquery(url): + """splitquery('/path?query') --> '/path', 'query'.""" + global _queryprog + if _queryprog is None: + import re + _queryprog = re.compile('^(.*)\?([^?]*)$') + + match = _queryprog.match(url) + if match: return match.group(1, 2) + return url, None + +_tagprog = None +def splittag(url): + """splittag('/path#tag') --> '/path', 'tag'.""" + global _tagprog + if _tagprog is None: + import re + _tagprog = re.compile('^(.*)#([^#]*)$') + + match = _tagprog.match(url) + if match: return match.group(1, 2) + return url, None + +def splitattr(url): + """splitattr('/path;attr1=value1;attr2=value2;...') -> + '/path', ['attr1=value1', 'attr2=value2', ...].""" + words = url.split(';') + return words[0], words[1:] + +_valueprog = None +def splitvalue(attr): + """splitvalue('attr=value') --> 'attr', 'value'.""" + global _valueprog + if _valueprog is None: + import re + _valueprog = re.compile('^([^=]*)=(.*)$') + + match = _valueprog.match(attr) + if match: return match.group(1, 2) + return attr, None + +# urlparse contains a duplicate of this method to avoid a circular import. If +# you update this method, also update the copy in urlparse. This code +# duplication does not exist in Python3. + +_hexdig = '0123456789ABCDEFabcdef' +_hextochr = dict((a + b, chr(int(a + b, 16))) + for a in _hexdig for b in _hexdig) +_asciire = re.compile('([\x00-\x7f]+)') + +def unquote(s): + """unquote('abc%20def') -> 'abc def'.""" + if _is_unicode(s): + if '%' not in s: + return s + bits = _asciire.split(s) + res = [bits[0]] + append = res.append + for i in range(1, len(bits), 2): + append(unquote(str(bits[i])).decode('latin1')) + append(bits[i + 1]) + return ''.join(res) + + bits = s.split('%') + # fastpath + if len(bits) == 1: + return s + res = [bits[0]] + append = res.append + for item in bits[1:]: + try: + append(_hextochr[item[:2]]) + append(item[2:]) + except KeyError: + append('%') + append(item) + return ''.join(res) + +def unquote_plus(s): + """unquote('%7e/abc+def') -> '~/abc def'""" + s = s.replace('+', ' ') + return unquote(s) + +always_safe = ('ABCDEFGHIJKLMNOPQRSTUVWXYZ' + 'abcdefghijklmnopqrstuvwxyz' + '0123456789' '_.-') +_safe_map = {} +for i, c in zip(xrange(256), str(bytearray(xrange(256)))): + _safe_map[c] = c if (i < 128 and c in always_safe) else '%{:02X}'.format(i) +_safe_quoters = {} + +def quote(s, safe='/'): + """quote('abc def') -> 'abc%20def' + + Each part of a URL, e.g. the path info, the query, etc., has a + different set of reserved characters that must be quoted. + + RFC 2396 Uniform Resource Identifiers (URI): Generic Syntax lists + the following reserved characters. + + reserved = ";" | "/" | "?" | ":" | "@" | "&" | "=" | "+" | + "$" | "," + + Each of these characters is reserved in some component of a URL, + but not necessarily in all of them. + + By default, the quote function is intended for quoting the path + section of a URL. Thus, it will not encode '/'. This character + is reserved, but in typical usage the quote function is being + called on a path where the existing slash characters are used as + reserved characters. + """ + # fastpath + if not s: + if s is None: + raise TypeError('None object cannot be quoted') + return s + cachekey = (safe, always_safe) + try: + (quoter, safe) = _safe_quoters[cachekey] + except KeyError: + safe_map = _safe_map.copy() + safe_map.update([(c, c) for c in safe]) + quoter = safe_map.__getitem__ + safe = always_safe + safe + _safe_quoters[cachekey] = (quoter, safe) + if not s.rstrip(safe): + return s + return ''.join(map(quoter, s)) + +def quote_plus(s, safe=''): + """Quote the query fragment of a URL; replacing ' ' with '+'""" + if ' ' in s: + s = quote(s, safe + ' ') + return s.replace(' ', '+') + return quote(s, safe) + +def urlencode(query, doseq=0): + """Encode a sequence of two-element tuples or dictionary into a URL query string. + + If any values in the query arg are sequences and doseq is true, each + sequence element is converted to a separate parameter. + + If the query arg is a sequence of two-element tuples, the order of the + parameters in the output will match the order of parameters in the + input. + """ + + if hasattr(query,"items"): + # mapping objects + query = query.items() + else: + # it's a bother at times that strings and string-like objects are + # sequences... + try: + # non-sequence items should not work with len() + # non-empty strings will fail this + if len(query) and not isinstance(query[0], tuple): + raise TypeError + # zero-length sequences of all types will get here and succeed, + # but that's a minor nit - since the original implementation + # allowed empty dicts that type of behavior probably should be + # preserved for consistency + except TypeError: + ty,va,tb = sys.exc_info() + raise TypeError, "not a valid non-string sequence or mapping object", tb + + l = [] + if not doseq: + # preserve old behavior + for k, v in query: + k = quote_plus(str(k)) + v = quote_plus(str(v)) + l.append(k + '=' + v) + else: + for k, v in query: + k = quote_plus(str(k)) + if isinstance(v, str): + v = quote_plus(v) + l.append(k + '=' + v) + elif _is_unicode(v): + # is there a reasonable way to convert to ASCII? + # encode generates a string, but "replace" or "ignore" + # lose information and "strict" can raise UnicodeError + v = quote_plus(v.encode("ASCII","replace")) + l.append(k + '=' + v) + else: + try: + # is this a sufficient test for sequence-ness? + len(v) + except TypeError: + # not a sequence + v = quote_plus(str(v)) + l.append(k + '=' + v) + else: + # loop over the sequence + for elt in v: + l.append(k + '=' + quote_plus(str(elt))) + return '&'.join(l) + +# Proxy handling +def getproxies_environment(): + """Return a dictionary of scheme -> proxy server URL mappings. + + Scan the environment for variables named _proxy; + this seems to be the standard convention. If you need a + different way, you can pass a proxies dictionary to the + [Fancy]URLopener constructor. + + """ + proxies = {} + for name, value in os.environ.items(): + name = name.lower() + if value and name[-6:] == '_proxy': + proxies[name[:-6]] = value + return proxies + +def proxy_bypass_environment(host): + """Test if proxies should not be used for a particular host. + + Checks the environment for a variable named no_proxy, which should + be a list of DNS suffixes separated by commas, or '*' for all hosts. + """ + no_proxy = os.environ.get('no_proxy', '') or os.environ.get('NO_PROXY', '') + # '*' is special case for always bypass + if no_proxy == '*': + return 1 + # strip port off host + hostonly, port = splitport(host) + # check if the host ends with any of the DNS suffixes + no_proxy_list = [proxy.strip() for proxy in no_proxy.split(',')] + for name in no_proxy_list: + if name and (hostonly.endswith(name) or host.endswith(name)): + return 1 + # otherwise, don't bypass + return 0 + + +if sys.platform == 'darwin': + from _scproxy import _get_proxy_settings, _get_proxies + + def proxy_bypass_macosx_sysconf(host): + """ + Return True iff this host shouldn't be accessed using a proxy + + This function uses the MacOSX framework SystemConfiguration + to fetch the proxy information. + """ + import re + import socket + from fnmatch import fnmatch + + hostonly, port = splitport(host) + + def ip2num(ipAddr): + parts = ipAddr.split('.') + parts = map(int, parts) + if len(parts) != 4: + parts = (parts + [0, 0, 0, 0])[:4] + return (parts[0] << 24) | (parts[1] << 16) | (parts[2] << 8) | parts[3] + + proxy_settings = _get_proxy_settings() + + # Check for simple host names: + if '.' not in host: + if proxy_settings['exclude_simple']: + return True + + hostIP = None + + for value in proxy_settings.get('exceptions', ()): + # Items in the list are strings like these: *.local, 169.254/16 + if not value: continue + + m = re.match(r"(\d+(?:\.\d+)*)(/\d+)?", value) + if m is not None: + if hostIP is None: + try: + hostIP = socket.gethostbyname(hostonly) + hostIP = ip2num(hostIP) + except socket.error: + continue + + base = ip2num(m.group(1)) + mask = m.group(2) + if mask is None: + mask = 8 * (m.group(1).count('.') + 1) + + else: + mask = int(mask[1:]) + mask = 32 - mask + + if (hostIP >> mask) == (base >> mask): + return True + + elif fnmatch(host, value): + return True + + return False + + def getproxies_macosx_sysconf(): + """Return a dictionary of scheme -> proxy server URL mappings. + + This function uses the MacOSX framework SystemConfiguration + to fetch the proxy information. + """ + return _get_proxies() + + def proxy_bypass(host): + if getproxies_environment(): + return proxy_bypass_environment(host) + else: + return proxy_bypass_macosx_sysconf(host) + + def getproxies(): + return getproxies_environment() or getproxies_macosx_sysconf() + +elif os.name == 'nt': + def getproxies_registry(): + """Return a dictionary of scheme -> proxy server URL mappings. + + Win32 uses the registry to store proxies. + + """ + proxies = {} + try: + import _winreg + except ImportError: + # Std module, so should be around - but you never know! + return proxies + try: + internetSettings = _winreg.OpenKey(_winreg.HKEY_CURRENT_USER, + r'Software\Microsoft\Windows\CurrentVersion\Internet Settings') + proxyEnable = _winreg.QueryValueEx(internetSettings, + 'ProxyEnable')[0] + if proxyEnable: + # Returned as Unicode but problems if not converted to ASCII + proxyServer = str(_winreg.QueryValueEx(internetSettings, + 'ProxyServer')[0]) + if '=' in proxyServer: + # Per-protocol settings + for p in proxyServer.split(';'): + protocol, address = p.split('=', 1) + # See if address has a type:// prefix + import re + if not re.match('^([^/:]+)://', address): + address = '%s://%s' % (protocol, address) + proxies[protocol] = address + else: + # Use one setting for all protocols + if proxyServer[:5] == 'http:': + proxies['http'] = proxyServer + else: + proxies['http'] = 'http://%s' % proxyServer + proxies['https'] = 'https://%s' % proxyServer + proxies['ftp'] = 'ftp://%s' % proxyServer + internetSettings.Close() + except (WindowsError, ValueError, TypeError): + # Either registry key not found etc, or the value in an + # unexpected format. + # proxies already set up to be empty so nothing to do + pass + return proxies + + def getproxies(): + """Return a dictionary of scheme -> proxy server URL mappings. + + Returns settings gathered from the environment, if specified, + or the registry. + + """ + return getproxies_environment() or getproxies_registry() + + def proxy_bypass_registry(host): + try: + import _winreg + import re + except ImportError: + # Std modules, so should be around - but you never know! + return 0 + try: + internetSettings = _winreg.OpenKey(_winreg.HKEY_CURRENT_USER, + r'Software\Microsoft\Windows\CurrentVersion\Internet Settings') + proxyEnable = _winreg.QueryValueEx(internetSettings, + 'ProxyEnable')[0] + proxyOverride = str(_winreg.QueryValueEx(internetSettings, + 'ProxyOverride')[0]) + # ^^^^ Returned as Unicode but problems if not converted to ASCII + except WindowsError: + return 0 + if not proxyEnable or not proxyOverride: + return 0 + # try to make a host list from name and IP address. + rawHost, port = splitport(host) + host = [rawHost] + try: + addr = socket.gethostbyname(rawHost) + if addr != rawHost: + host.append(addr) + except socket.error: + pass + try: + fqdn = socket.getfqdn(rawHost) + if fqdn != rawHost: + host.append(fqdn) + except socket.error: + pass + # make a check value list from the registry entry: replace the + # '' string by the localhost entry and the corresponding + # canonical entry. + proxyOverride = proxyOverride.split(';') + # now check if we match one of the registry values. + for test in proxyOverride: + if test == '': + if '.' not in rawHost: + return 1 + test = test.replace(".", r"\.") # mask dots + test = test.replace("*", r".*") # change glob sequence + test = test.replace("?", r".") # change glob char + for val in host: + # print "%s <--> %s" %( test, val ) + if re.match(test, val, re.I): + return 1 + return 0 + + def proxy_bypass(host): + """Return a dictionary of scheme -> proxy server URL mappings. + + Returns settings gathered from the environment, if specified, + or the registry. + + """ + if getproxies_environment(): + return proxy_bypass_environment(host) + else: + return proxy_bypass_registry(host) + +else: + # By default use environment variables + getproxies = getproxies_environment + proxy_bypass = proxy_bypass_environment + +# Test and time quote() and unquote() +def test1(): + s = '' + for i in range(256): s = s + chr(i) + s = s*4 + t0 = time.time() + qs = quote(s) + uqs = unquote(qs) + t1 = time.time() + if uqs != s: + print 'Wrong!' + print repr(s) + print repr(qs) + print repr(uqs) + print round(t1 - t0, 3), 'sec' + + +def reporthook(blocknum, blocksize, totalsize): + # Report during remote transfers + print "Block number: %d, Block size: %d, Total size: %d" % ( + blocknum, blocksize, totalsize) diff --git a/urllib2.py b/urllib2.py new file mode 100644 index 0000000..9277b1d --- /dev/null +++ b/urllib2.py @@ -0,0 +1,1488 @@ +"""An extensible library for opening URLs using a variety of protocols + +The simplest way to use this module is to call the urlopen function, +which accepts a string containing a URL or a Request object (described +below). It opens the URL and returns the results as file-like +object; the returned object has some extra methods described below. + +The OpenerDirector manages a collection of Handler objects that do +all the actual work. Each Handler implements a particular protocol or +option. The OpenerDirector is a composite object that invokes the +Handlers needed to open the requested URL. For example, the +HTTPHandler performs HTTP GET and POST requests and deals with +non-error returns. The HTTPRedirectHandler automatically deals with +HTTP 301, 302, 303 and 307 redirect errors, and the HTTPDigestAuthHandler +deals with digest authentication. + +urlopen(url, data=None) -- Basic usage is the same as original +urllib. pass the url and optionally data to post to an HTTP URL, and +get a file-like object back. One difference is that you can also pass +a Request instance instead of URL. Raises a URLError (subclass of +IOError); for HTTP errors, raises an HTTPError, which can also be +treated as a valid response. + +build_opener -- Function that creates a new OpenerDirector instance. +Will install the default handlers. Accepts one or more Handlers as +arguments, either instances or Handler classes that it will +instantiate. If one of the argument is a subclass of the default +handler, the argument will be installed instead of the default. + +install_opener -- Installs a new opener as the default opener. + +objects of interest: + +OpenerDirector -- Sets up the User Agent as the Python-urllib client and manages +the Handler classes, while dealing with requests and responses. + +Request -- An object that encapsulates the state of a request. The +state can be as simple as the URL. It can also include extra HTTP +headers, e.g. a User-Agent. + +BaseHandler -- + +exceptions: +URLError -- A subclass of IOError, individual protocols have their own +specific subclass. + +HTTPError -- Also a valid HTTP response, so you can treat an HTTP error +as an exceptional event or valid response. + +internals: +BaseHandler and parent +_call_chain conventions + +Example usage: + +import urllib2 + +# set up authentication info +authinfo = urllib2.HTTPBasicAuthHandler() +authinfo.add_password(realm='PDQ Application', + uri='https://mahler:8092/site-updates.py', + user='klem', + passwd='geheim$parole') + +proxy_support = urllib2.ProxyHandler({"http" : "http://ahad-haam:3128"}) + +# build a new opener that adds authentication and caching FTP handlers +opener = urllib2.build_opener(proxy_support, authinfo, urllib2.CacheFTPHandler) + +# install it +urllib2.install_opener(opener) + +f = urllib2.urlopen('http://www.python.org/') + + +""" + +# XXX issues: +# If an authentication error handler that tries to perform +# authentication for some reason but fails, how should the error be +# signalled? The client needs to know the HTTP error code. But if +# the handler knows that the problem was, e.g., that it didn't know +# that hash algo that requested in the challenge, it would be good to +# pass that information along to the client, too. +# ftp errors aren't handled cleanly +# check digest against correct (i.e. non-apache) implementation + +# Possible extensions: +# complex proxies XXX not sure what exactly was meant by this +# abstract factory for opener + +import base64 +import hashlib +import httplib +import mimetools +import os +import posixpath +import random +import re +import socket +import sys +import time +import urlparse +import bisect +import warnings + +try: + from cStringIO import StringIO +except ImportError: + from StringIO import StringIO + +# check for SSL +try: + import ssl +except ImportError: + _have_ssl = False +else: + _have_ssl = True + +from urllib import (unwrap, unquote, splittype, splithost, quote, + addinfourl, splitport, splittag, toBytes, + splitattr, ftpwrapper, splituser, splitpasswd, splitvalue) + +# support for FileHandler, proxies via environment variables +from urllib import localhost, url2pathname, getproxies, proxy_bypass + +# used in User-Agent header sent +__version__ = sys.version[:3] + +_opener = None +def urlopen(url, data=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, + cafile=None, capath=None, cadefault=False, context=None): + global _opener + if cafile or capath or cadefault: + if context is not None: + raise ValueError( + "You can't pass both context and any of cafile, capath, and " + "cadefault" + ) + if not _have_ssl: + raise ValueError('SSL support not available') + context = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH, + cafile=cafile, + capath=capath) + https_handler = HTTPSHandler(context=context) + opener = build_opener(https_handler) + elif context: + https_handler = HTTPSHandler(context=context) + opener = build_opener(https_handler) + elif _opener is None: + _opener = opener = build_opener() + else: + opener = _opener + return opener.open(url, data, timeout) + +def install_opener(opener): + global _opener + _opener = opener + +# do these error classes make sense? +# make sure all of the IOError stuff is overridden. we just want to be +# subtypes. + +class URLError(IOError): + # URLError is a sub-type of IOError, but it doesn't share any of + # the implementation. need to override __init__ and __str__. + # It sets self.args for compatibility with other EnvironmentError + # subclasses, but args doesn't have the typical format with errno in + # slot 0 and strerror in slot 1. This may be better than nothing. + def __init__(self, reason): + self.args = reason, + self.reason = reason + + def __str__(self): + return '' % self.reason + +class HTTPError(URLError, addinfourl): + """Raised when HTTP error occurs, but also acts like non-error return""" + __super_init = addinfourl.__init__ + + def __init__(self, url, code, msg, hdrs, fp): + self.code = code + self.msg = msg + self.hdrs = hdrs + self.fp = fp + self.filename = url + # The addinfourl classes depend on fp being a valid file + # object. In some cases, the HTTPError may not have a valid + # file object. If this happens, the simplest workaround is to + # not initialize the base classes. + if fp is not None: + self.__super_init(fp, hdrs, url, code) + + def __str__(self): + return 'HTTP Error %s: %s' % (self.code, self.msg) + + # since URLError specifies a .reason attribute, HTTPError should also + # provide this attribute. See issue13211 fo discussion. + @property + def reason(self): + return self.msg + + def info(self): + return self.hdrs + +# copied from cookielib.py +_cut_port_re = re.compile(r":\d+$") +def request_host(request): + """Return request-host, as defined by RFC 2965. + + Variation from RFC: returned value is lowercased, for convenient + comparison. + + """ + url = request.get_full_url() + host = urlparse.urlparse(url)[1] + if host == "": + host = request.get_header("Host", "") + + # remove port, if present + host = _cut_port_re.sub("", host, 1) + return host.lower() + +class Request: + + def __init__(self, url, data=None, headers={}, + origin_req_host=None, unverifiable=False): + # unwrap('') --> 'type://host/path' + self.__original = unwrap(url) + self.__original, self.__fragment = splittag(self.__original) + self.type = None + # self.__r_type is what's left after doing the splittype + self.host = None + self.port = None + self._tunnel_host = None + self.data = data + self.headers = {} + for key, value in headers.items(): + self.add_header(key, value) + self.unredirected_hdrs = {} + if origin_req_host is None: + origin_req_host = request_host(self) + self.origin_req_host = origin_req_host + self.unverifiable = unverifiable + + def __getattr__(self, attr): + # XXX this is a fallback mechanism to guard against these + # methods getting called in a non-standard order. this may be + # too complicated and/or unnecessary. + # XXX should the __r_XXX attributes be public? + if attr[:12] == '_Request__r_': + name = attr[12:] + if hasattr(Request, 'get_' + name): + getattr(self, 'get_' + name)() + return getattr(self, attr) + raise AttributeError, attr + + def get_method(self): + if self.has_data(): + return "POST" + else: + return "GET" + + # XXX these helper methods are lame + + def add_data(self, data): + self.data = data + + def has_data(self): + return self.data is not None + + def get_data(self): + return self.data + + def get_full_url(self): + if self.__fragment: + return '%s#%s' % (self.__original, self.__fragment) + else: + return self.__original + + def get_type(self): + if self.type is None: + self.type, self.__r_type = splittype(self.__original) + if self.type is None: + raise ValueError, "unknown url type: %s" % self.__original + return self.type + + def get_host(self): + if self.host is None: + self.host, self.__r_host = splithost(self.__r_type) + if self.host: + self.host = unquote(self.host) + return self.host + + def get_selector(self): + return self.__r_host + + def set_proxy(self, host, type): + if self.type == 'https' and not self._tunnel_host: + self._tunnel_host = self.host + else: + self.type = type + self.__r_host = self.__original + + self.host = host + + def has_proxy(self): + return self.__r_host == self.__original + + def get_origin_req_host(self): + return self.origin_req_host + + def is_unverifiable(self): + return self.unverifiable + + def add_header(self, key, val): + # useful for something like authentication + self.headers[key.capitalize()] = val + + def add_unredirected_header(self, key, val): + # will not be added to a redirected request + self.unredirected_hdrs[key.capitalize()] = val + + def has_header(self, header_name): + return (header_name in self.headers or + header_name in self.unredirected_hdrs) + + def get_header(self, header_name, default=None): + return self.headers.get( + header_name, + self.unredirected_hdrs.get(header_name, default)) + + def header_items(self): + hdrs = self.unredirected_hdrs.copy() + hdrs.update(self.headers) + return hdrs.items() + +class OpenerDirector: + def __init__(self): + client_version = "Python-urllib/%s" % __version__ + self.addheaders = [('User-agent', client_version)] + # self.handlers is retained only for backward compatibility + self.handlers = [] + # manage the individual handlers + self.handle_open = {} + self.handle_error = {} + self.process_response = {} + self.process_request = {} + + def add_handler(self, handler): + if not hasattr(handler, "add_parent"): + raise TypeError("expected BaseHandler instance, got %r" % + type(handler)) + + added = False + for meth in dir(handler): + if meth in ["redirect_request", "do_open", "proxy_open"]: + # oops, coincidental match + continue + + i = meth.find("_") + protocol = meth[:i] + condition = meth[i+1:] + + if condition.startswith("error"): + j = condition.find("_") + i + 1 + kind = meth[j+1:] + try: + kind = int(kind) + except ValueError: + pass + lookup = self.handle_error.get(protocol, {}) + self.handle_error[protocol] = lookup + elif condition == "open": + kind = protocol + lookup = self.handle_open + elif condition == "response": + kind = protocol + lookup = self.process_response + elif condition == "request": + kind = protocol + lookup = self.process_request + else: + continue + + handlers = lookup.setdefault(kind, []) + if handlers: + bisect.insort(handlers, handler) + else: + handlers.append(handler) + added = True + + if added: + bisect.insort(self.handlers, handler) + handler.add_parent(self) + + def close(self): + # Only exists for backwards compatibility. + pass + + def _call_chain(self, chain, kind, meth_name, *args): + # Handlers raise an exception if no one else should try to handle + # the request, or return None if they can't but another handler + # could. Otherwise, they return the response. + handlers = chain.get(kind, ()) + for handler in handlers: + func = getattr(handler, meth_name) + + result = func(*args) + if result is not None: + return result + + def open(self, fullurl, data=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT): + # accept a URL or a Request object + if isinstance(fullurl, basestring): + req = Request(fullurl, data) + else: + req = fullurl + if data is not None: + req.add_data(data) + + req.timeout = timeout + protocol = req.get_type() + + # pre-process request + meth_name = protocol+"_request" + for processor in self.process_request.get(protocol, []): + meth = getattr(processor, meth_name) + req = meth(req) + + response = self._open(req, data) + + # post-process response + meth_name = protocol+"_response" + for processor in self.process_response.get(protocol, []): + meth = getattr(processor, meth_name) + response = meth(req, response) + + return response + + def _open(self, req, data=None): + result = self._call_chain(self.handle_open, 'default', + 'default_open', req) + if result: + return result + + protocol = req.get_type() + result = self._call_chain(self.handle_open, protocol, protocol + + '_open', req) + if result: + return result + + return self._call_chain(self.handle_open, 'unknown', + 'unknown_open', req) + + def error(self, proto, *args): + if proto in ('http', 'https'): + # XXX http[s] protocols are special-cased + dict = self.handle_error['http'] # https is not different than http + proto = args[2] # YUCK! + meth_name = 'http_error_%s' % proto + http_err = 1 + orig_args = args + else: + dict = self.handle_error + meth_name = proto + '_error' + http_err = 0 + args = (dict, proto, meth_name) + args + result = self._call_chain(*args) + if result: + return result + + if http_err: + args = (dict, 'default', 'http_error_default') + orig_args + return self._call_chain(*args) + +# XXX probably also want an abstract factory that knows when it makes +# sense to skip a superclass in favor of a subclass and when it might +# make sense to include both + +def build_opener(*handlers): + """Create an opener object from a list of handlers. + + The opener will use several default handlers, including support + for HTTP, FTP and when applicable, HTTPS. + + If any of the handlers passed as arguments are subclasses of the + default handlers, the default handlers will not be used. + """ + import types + def isclass(obj): + return isinstance(obj, (types.ClassType, type)) + + opener = OpenerDirector() + default_classes = [ProxyHandler, UnknownHandler, HTTPHandler, + HTTPDefaultErrorHandler, HTTPRedirectHandler, + FTPHandler, FileHandler, HTTPErrorProcessor] + if hasattr(httplib, 'HTTPS'): + default_classes.append(HTTPSHandler) + skip = set() + for klass in default_classes: + for check in handlers: + if isclass(check): + if issubclass(check, klass): + skip.add(klass) + elif isinstance(check, klass): + skip.add(klass) + for klass in skip: + default_classes.remove(klass) + + for klass in default_classes: + opener.add_handler(klass()) + + for h in handlers: + if isclass(h): + h = h() + opener.add_handler(h) + return opener + +class BaseHandler: + handler_order = 500 + + def add_parent(self, parent): + self.parent = parent + + def close(self): + # Only exists for backwards compatibility + pass + + def __lt__(self, other): + if not hasattr(other, "handler_order"): + # Try to preserve the old behavior of having custom classes + # inserted after default ones (works only for custom user + # classes which are not aware of handler_order). + return True + return self.handler_order < other.handler_order + + +class HTTPErrorProcessor(BaseHandler): + """Process HTTP error responses.""" + handler_order = 1000 # after all other processing + + def http_response(self, request, response): + code, msg, hdrs = response.code, response.msg, response.info() + + # According to RFC 2616, "2xx" code indicates that the client's + # request was successfully received, understood, and accepted. + if not (200 <= code < 300): + response = self.parent.error( + 'http', request, response, code, msg, hdrs) + + return response + + https_response = http_response + +class HTTPDefaultErrorHandler(BaseHandler): + def http_error_default(self, req, fp, code, msg, hdrs): + raise HTTPError(req.get_full_url(), code, msg, hdrs, fp) + +class HTTPRedirectHandler(BaseHandler): + # maximum number of redirections to any single URL + # this is needed because of the state that cookies introduce + max_repeats = 4 + # maximum total number of redirections (regardless of URL) before + # assuming we're in a loop + max_redirections = 10 + + def redirect_request(self, req, fp, code, msg, headers, newurl): + """Return a Request or None in response to a redirect. + + This is called by the http_error_30x methods when a + redirection response is received. If a redirection should + take place, return a new Request to allow http_error_30x to + perform the redirect. Otherwise, raise HTTPError if no-one + else should try to handle this url. Return None if you can't + but another Handler might. + """ + m = req.get_method() + if (code in (301, 302, 303, 307) and m in ("GET", "HEAD") + or code in (301, 302, 303) and m == "POST"): + # Strictly (according to RFC 2616), 301 or 302 in response + # to a POST MUST NOT cause a redirection without confirmation + # from the user (of urllib2, in this case). In practice, + # essentially all clients do redirect in this case, so we + # do the same. + # be conciliant with URIs containing a space + newurl = newurl.replace(' ', '%20') + newheaders = dict((k,v) for k,v in req.headers.items() + if k.lower() not in ("content-length", "content-type") + ) + return Request(newurl, + headers=newheaders, + origin_req_host=req.get_origin_req_host(), + unverifiable=True) + else: + raise HTTPError(req.get_full_url(), code, msg, headers, fp) + + # Implementation note: To avoid the server sending us into an + # infinite loop, the request object needs to track what URLs we + # have already seen. Do this by adding a handler-specific + # attribute to the Request object. + def http_error_302(self, req, fp, code, msg, headers): + # Some servers (incorrectly) return multiple Location headers + # (so probably same goes for URI). Use first header. + if 'location' in headers: + newurl = headers.getheaders('location')[0] + elif 'uri' in headers: + newurl = headers.getheaders('uri')[0] + else: + return + + # fix a possible malformed URL + urlparts = urlparse.urlparse(newurl) + if not urlparts.path: + urlparts = list(urlparts) + urlparts[2] = "/" + newurl = urlparse.urlunparse(urlparts) + + newurl = urlparse.urljoin(req.get_full_url(), newurl) + + # For security reasons we do not allow redirects to protocols + # other than HTTP, HTTPS or FTP. + newurl_lower = newurl.lower() + if not (newurl_lower.startswith('http://') or + newurl_lower.startswith('https://') or + newurl_lower.startswith('ftp://')): + raise HTTPError(newurl, code, + msg + " - Redirection to url '%s' is not allowed" % + newurl, + headers, fp) + + # XXX Probably want to forget about the state of the current + # request, although that might interact poorly with other + # handlers that also use handler-specific request attributes + new = self.redirect_request(req, fp, code, msg, headers, newurl) + if new is None: + return + + # loop detection + # .redirect_dict has a key url if url was previously visited. + if hasattr(req, 'redirect_dict'): + visited = new.redirect_dict = req.redirect_dict + if (visited.get(newurl, 0) >= self.max_repeats or + len(visited) >= self.max_redirections): + raise HTTPError(req.get_full_url(), code, + self.inf_msg + msg, headers, fp) + else: + visited = new.redirect_dict = req.redirect_dict = {} + visited[newurl] = visited.get(newurl, 0) + 1 + + # Don't close the fp until we are sure that we won't use it + # with HTTPError. + fp.read() + fp.close() + + return self.parent.open(new, timeout=req.timeout) + + http_error_301 = http_error_303 = http_error_307 = http_error_302 + + inf_msg = "The HTTP server returned a redirect error that would " \ + "lead to an infinite loop.\n" \ + "The last 30x error message was:\n" + + +def _parse_proxy(proxy): + """Return (scheme, user, password, host/port) given a URL or an authority. + + If a URL is supplied, it must have an authority (host:port) component. + According to RFC 3986, having an authority component means the URL must + have two slashes after the scheme: + + >>> _parse_proxy('file:/ftp.example.com/') + Traceback (most recent call last): + ValueError: proxy URL with no authority: 'file:/ftp.example.com/' + + The first three items of the returned tuple may be None. + + Examples of authority parsing: + + >>> _parse_proxy('proxy.example.com') + (None, None, None, 'proxy.example.com') + >>> _parse_proxy('proxy.example.com:3128') + (None, None, None, 'proxy.example.com:3128') + + The authority component may optionally include userinfo (assumed to be + username:password): + + >>> _parse_proxy('joe:password@proxy.example.com') + (None, 'joe', 'password', 'proxy.example.com') + >>> _parse_proxy('joe:password@proxy.example.com:3128') + (None, 'joe', 'password', 'proxy.example.com:3128') + + Same examples, but with URLs instead: + + >>> _parse_proxy('http://proxy.example.com/') + ('http', None, None, 'proxy.example.com') + >>> _parse_proxy('http://proxy.example.com:3128/') + ('http', None, None, 'proxy.example.com:3128') + >>> _parse_proxy('http://joe:password@proxy.example.com/') + ('http', 'joe', 'password', 'proxy.example.com') + >>> _parse_proxy('http://joe:password@proxy.example.com:3128') + ('http', 'joe', 'password', 'proxy.example.com:3128') + + Everything after the authority is ignored: + + >>> _parse_proxy('ftp://joe:password@proxy.example.com/rubbish:3128') + ('ftp', 'joe', 'password', 'proxy.example.com') + + Test for no trailing '/' case: + + >>> _parse_proxy('http://joe:password@proxy.example.com') + ('http', 'joe', 'password', 'proxy.example.com') + + """ + scheme, r_scheme = splittype(proxy) + if not r_scheme.startswith("/"): + # authority + scheme = None + authority = proxy + else: + # URL + if not r_scheme.startswith("//"): + raise ValueError("proxy URL with no authority: %r" % proxy) + # We have an authority, so for RFC 3986-compliant URLs (by ss 3. + # and 3.3.), path is empty or starts with '/' + end = r_scheme.find("/", 2) + if end == -1: + end = None + authority = r_scheme[2:end] + userinfo, hostport = splituser(authority) + if userinfo is not None: + user, password = splitpasswd(userinfo) + else: + user = password = None + return scheme, user, password, hostport + +class ProxyHandler(BaseHandler): + # Proxies must be in front + handler_order = 100 + + def __init__(self, proxies=None): + if proxies is None: + proxies = getproxies() + assert hasattr(proxies, 'has_key'), "proxies must be a mapping" + self.proxies = proxies + for type, url in proxies.items(): + setattr(self, '%s_open' % type, + lambda r, proxy=url, type=type, meth=self.proxy_open: \ + meth(r, proxy, type)) + + def proxy_open(self, req, proxy, type): + orig_type = req.get_type() + proxy_type, user, password, hostport = _parse_proxy(proxy) + + if proxy_type is None: + proxy_type = orig_type + + if req.host and proxy_bypass(req.host): + return None + + if user and password: + user_pass = '%s:%s' % (unquote(user), unquote(password)) + creds = base64.b64encode(user_pass).strip() + req.add_header('Proxy-authorization', 'Basic ' + creds) + hostport = unquote(hostport) + req.set_proxy(hostport, proxy_type) + + if orig_type == proxy_type or orig_type == 'https': + # let other handlers take care of it + return None + else: + # need to start over, because the other handlers don't + # grok the proxy's URL type + # e.g. if we have a constructor arg proxies like so: + # {'http': 'ftp://proxy.example.com'}, we may end up turning + # a request for http://acme.example.com/a into one for + # ftp://proxy.example.com/a + return self.parent.open(req, timeout=req.timeout) + +class HTTPPasswordMgr: + + def __init__(self): + self.passwd = {} + + def add_password(self, realm, uri, user, passwd): + # uri could be a single URI or a sequence + if isinstance(uri, basestring): + uri = [uri] + if not realm in self.passwd: + self.passwd[realm] = {} + for default_port in True, False: + reduced_uri = tuple( + [self.reduce_uri(u, default_port) for u in uri]) + self.passwd[realm][reduced_uri] = (user, passwd) + + def find_user_password(self, realm, authuri): + domains = self.passwd.get(realm, {}) + for default_port in True, False: + reduced_authuri = self.reduce_uri(authuri, default_port) + for uris, authinfo in domains.iteritems(): + for uri in uris: + if self.is_suburi(uri, reduced_authuri): + return authinfo + return None, None + + def reduce_uri(self, uri, default_port=True): + """Accept authority or URI and extract only the authority and path.""" + # note HTTP URLs do not have a userinfo component + parts = urlparse.urlsplit(uri) + if parts[1]: + # URI + scheme = parts[0] + authority = parts[1] + path = parts[2] or '/' + else: + # host or host:port + scheme = None + authority = uri + path = '/' + host, port = splitport(authority) + if default_port and port is None and scheme is not None: + dport = {"http": 80, + "https": 443, + }.get(scheme) + if dport is not None: + authority = "%s:%d" % (host, dport) + return authority, path + + def is_suburi(self, base, test): + """Check if test is below base in a URI tree + + Both args must be URIs in reduced form. + """ + if base == test: + return True + if base[0] != test[0]: + return False + common = posixpath.commonprefix((base[1], test[1])) + if len(common) == len(base[1]): + return True + return False + + +class HTTPPasswordMgrWithDefaultRealm(HTTPPasswordMgr): + + def find_user_password(self, realm, authuri): + user, password = HTTPPasswordMgr.find_user_password(self, realm, + authuri) + if user is not None: + return user, password + return HTTPPasswordMgr.find_user_password(self, None, authuri) + + +class AbstractBasicAuthHandler: + + # XXX this allows for multiple auth-schemes, but will stupidly pick + # the last one with a realm specified. + + # allow for double- and single-quoted realm values + # (single quotes are a violation of the RFC, but appear in the wild) + rx = re.compile('(?:.*,)*[ \t]*([^ \t]+)[ \t]+' + 'realm=(["\']?)([^"\']*)\\2', re.I) + + # XXX could pre-emptively send auth info already accepted (RFC 2617, + # end of section 2, and section 1.2 immediately after "credentials" + # production). + + def __init__(self, password_mgr=None): + if password_mgr is None: + password_mgr = HTTPPasswordMgr() + self.passwd = password_mgr + self.add_password = self.passwd.add_password + + + def http_error_auth_reqed(self, authreq, host, req, headers): + # host may be an authority (without userinfo) or a URL with an + # authority + # XXX could be multiple headers + authreq = headers.get(authreq, None) + + if authreq: + mo = AbstractBasicAuthHandler.rx.search(authreq) + if mo: + scheme, quote, realm = mo.groups() + if quote not in ['"', "'"]: + warnings.warn("Basic Auth Realm was unquoted", + UserWarning, 2) + if scheme.lower() == 'basic': + return self.retry_http_basic_auth(host, req, realm) + + def retry_http_basic_auth(self, host, req, realm): + user, pw = self.passwd.find_user_password(realm, host) + if pw is not None: + raw = "%s:%s" % (user, pw) + auth = 'Basic %s' % base64.b64encode(raw).strip() + if req.get_header(self.auth_header, None) == auth: + return None + req.add_unredirected_header(self.auth_header, auth) + return self.parent.open(req, timeout=req.timeout) + else: + return None + + +class HTTPBasicAuthHandler(AbstractBasicAuthHandler, BaseHandler): + + auth_header = 'Authorization' + + def http_error_401(self, req, fp, code, msg, headers): + url = req.get_full_url() + response = self.http_error_auth_reqed('www-authenticate', + url, req, headers) + return response + + +class ProxyBasicAuthHandler(AbstractBasicAuthHandler, BaseHandler): + + auth_header = 'Proxy-authorization' + + def http_error_407(self, req, fp, code, msg, headers): + # http_error_auth_reqed requires that there is no userinfo component in + # authority. Assume there isn't one, since urllib2 does not (and + # should not, RFC 3986 s. 3.2.1) support requests for URLs containing + # userinfo. + authority = req.get_host() + response = self.http_error_auth_reqed('proxy-authenticate', + authority, req, headers) + return response + + +def randombytes(n): + """Return n random bytes.""" + # Use /dev/urandom if it is available. Fall back to random module + # if not. It might be worthwhile to extend this function to use + # other platform-specific mechanisms for getting random bytes. + if os.path.exists("/dev/urandom"): + f = open("/dev/urandom") + s = f.read(n) + f.close() + return s + else: + L = [chr(random.randrange(0, 256)) for i in range(n)] + return "".join(L) + +class AbstractDigestAuthHandler: + # Digest authentication is specified in RFC 2617. + + # XXX The client does not inspect the Authentication-Info header + # in a successful response. + + # XXX It should be possible to test this implementation against + # a mock server that just generates a static set of challenges. + + # XXX qop="auth-int" supports is shaky + + def __init__(self, passwd=None): + if passwd is None: + passwd = HTTPPasswordMgr() + self.passwd = passwd + self.add_password = self.passwd.add_password + self.retried = 0 + self.nonce_count = 0 + self.last_nonce = None + + def reset_retry_count(self): + self.retried = 0 + + def http_error_auth_reqed(self, auth_header, host, req, headers): + authreq = headers.get(auth_header, None) + if self.retried > 5: + # Don't fail endlessly - if we failed once, we'll probably + # fail a second time. Hm. Unless the Password Manager is + # prompting for the information. Crap. This isn't great + # but it's better than the current 'repeat until recursion + # depth exceeded' approach + raise HTTPError(req.get_full_url(), 401, "digest auth failed", + headers, None) + else: + self.retried += 1 + if authreq: + scheme = authreq.split()[0] + if scheme.lower() == 'digest': + return self.retry_http_digest_auth(req, authreq) + + def retry_http_digest_auth(self, req, auth): + token, challenge = auth.split(' ', 1) + chal = parse_keqv_list(parse_http_list(challenge)) + auth = self.get_authorization(req, chal) + if auth: + auth_val = 'Digest %s' % auth + if req.headers.get(self.auth_header, None) == auth_val: + return None + req.add_unredirected_header(self.auth_header, auth_val) + resp = self.parent.open(req, timeout=req.timeout) + return resp + + def get_cnonce(self, nonce): + # The cnonce-value is an opaque + # quoted string value provided by the client and used by both client + # and server to avoid chosen plaintext attacks, to provide mutual + # authentication, and to provide some message integrity protection. + # This isn't a fabulous effort, but it's probably Good Enough. + dig = hashlib.sha1("%s:%s:%s:%s" % (self.nonce_count, nonce, time.ctime(), + randombytes(8))).hexdigest() + return dig[:16] + + def get_authorization(self, req, chal): + try: + realm = chal['realm'] + nonce = chal['nonce'] + qop = chal.get('qop') + algorithm = chal.get('algorithm', 'MD5') + # mod_digest doesn't send an opaque, even though it isn't + # supposed to be optional + opaque = chal.get('opaque', None) + except KeyError: + return None + + H, KD = self.get_algorithm_impls(algorithm) + if H is None: + return None + + user, pw = self.passwd.find_user_password(realm, req.get_full_url()) + if user is None: + return None + + # XXX not implemented yet + if req.has_data(): + entdig = self.get_entity_digest(req.get_data(), chal) + else: + entdig = None + + A1 = "%s:%s:%s" % (user, realm, pw) + A2 = "%s:%s" % (req.get_method(), + # XXX selector: what about proxies and full urls + req.get_selector()) + if qop == 'auth': + if nonce == self.last_nonce: + self.nonce_count += 1 + else: + self.nonce_count = 1 + self.last_nonce = nonce + + ncvalue = '%08x' % self.nonce_count + cnonce = self.get_cnonce(nonce) + noncebit = "%s:%s:%s:%s:%s" % (nonce, ncvalue, cnonce, qop, H(A2)) + respdig = KD(H(A1), noncebit) + elif qop is None: + respdig = KD(H(A1), "%s:%s" % (nonce, H(A2))) + else: + # XXX handle auth-int. + raise URLError("qop '%s' is not supported." % qop) + + # XXX should the partial digests be encoded too? + + base = 'username="%s", realm="%s", nonce="%s", uri="%s", ' \ + 'response="%s"' % (user, realm, nonce, req.get_selector(), + respdig) + if opaque: + base += ', opaque="%s"' % opaque + if entdig: + base += ', digest="%s"' % entdig + base += ', algorithm="%s"' % algorithm + if qop: + base += ', qop=auth, nc=%s, cnonce="%s"' % (ncvalue, cnonce) + return base + + def get_algorithm_impls(self, algorithm): + # algorithm should be case-insensitive according to RFC2617 + algorithm = algorithm.upper() + # lambdas assume digest modules are imported at the top level + if algorithm == 'MD5': + H = lambda x: hashlib.md5(x).hexdigest() + elif algorithm == 'SHA': + H = lambda x: hashlib.sha1(x).hexdigest() + # XXX MD5-sess + KD = lambda s, d: H("%s:%s" % (s, d)) + return H, KD + + def get_entity_digest(self, data, chal): + # XXX not implemented yet + return None + + +class HTTPDigestAuthHandler(BaseHandler, AbstractDigestAuthHandler): + """An authentication protocol defined by RFC 2069 + + Digest authentication improves on basic authentication because it + does not transmit passwords in the clear. + """ + + auth_header = 'Authorization' + handler_order = 490 # before Basic auth + + def http_error_401(self, req, fp, code, msg, headers): + host = urlparse.urlparse(req.get_full_url())[1] + retry = self.http_error_auth_reqed('www-authenticate', + host, req, headers) + self.reset_retry_count() + return retry + + +class ProxyDigestAuthHandler(BaseHandler, AbstractDigestAuthHandler): + + auth_header = 'Proxy-Authorization' + handler_order = 490 # before Basic auth + + def http_error_407(self, req, fp, code, msg, headers): + host = req.get_host() + retry = self.http_error_auth_reqed('proxy-authenticate', + host, req, headers) + self.reset_retry_count() + return retry + +class AbstractHTTPHandler(BaseHandler): + + def __init__(self, debuglevel=0): + self._debuglevel = debuglevel + + def set_http_debuglevel(self, level): + self._debuglevel = level + + def do_request_(self, request): + host = request.get_host() + if not host: + raise URLError('no host given') + + if request.has_data(): # POST + data = request.get_data() + if not request.has_header('Content-type'): + request.add_unredirected_header( + 'Content-type', + 'application/x-www-form-urlencoded') + if not request.has_header('Content-length'): + request.add_unredirected_header( + 'Content-length', '%d' % len(data)) + + sel_host = host + if request.has_proxy(): + scheme, sel = splittype(request.get_selector()) + sel_host, sel_path = splithost(sel) + + if not request.has_header('Host'): + request.add_unredirected_header('Host', sel_host) + for name, value in self.parent.addheaders: + name = name.capitalize() + if not request.has_header(name): + request.add_unredirected_header(name, value) + + return request + + def do_open(self, http_class, req, **http_conn_args): + """Return an addinfourl object for the request, using http_class. + + http_class must implement the HTTPConnection API from httplib. + The addinfourl return value is a file-like object. It also + has methods and attributes including: + - info(): return a mimetools.Message object for the headers + - geturl(): return the original request URL + - code: HTTP status code + """ + host = req.get_host() + if not host: + raise URLError('no host given') + + # will parse host:port + h = http_class(host, timeout=req.timeout, **http_conn_args) + h.set_debuglevel(self._debuglevel) + + headers = dict(req.unredirected_hdrs) + headers.update(dict((k, v) for k, v in req.headers.items() + if k not in headers)) + + # We want to make an HTTP/1.1 request, but the addinfourl + # class isn't prepared to deal with a persistent connection. + # It will try to read all remaining data from the socket, + # which will block while the server waits for the next request. + # So make sure the connection gets closed after the (only) + # request. + headers["Connection"] = "close" + headers = dict( + (name.title(), val) for name, val in headers.items()) + + if req._tunnel_host: + tunnel_headers = {} + proxy_auth_hdr = "Proxy-Authorization" + if proxy_auth_hdr in headers: + tunnel_headers[proxy_auth_hdr] = headers[proxy_auth_hdr] + # Proxy-Authorization should not be sent to origin + # server. + del headers[proxy_auth_hdr] + h.set_tunnel(req._tunnel_host, headers=tunnel_headers) + + try: + h.request(req.get_method(), req.get_selector(), req.data, headers) + except socket.error, err: # XXX what error? + h.close() + raise URLError(err) + else: + try: + r = h.getresponse(buffering=True) + except TypeError: # buffering kw not supported + r = h.getresponse() + + # Pick apart the HTTPResponse object to get the addinfourl + # object initialized properly. + + # Wrap the HTTPResponse object in socket's file object adapter + # for Windows. That adapter calls recv(), so delegate recv() + # to read(). This weird wrapping allows the returned object to + # have readline() and readlines() methods. + + # XXX It might be better to extract the read buffering code + # out of socket._fileobject() and into a base class. + + r.recv = r.read + fp = socket._fileobject(r, close=True) + + resp = addinfourl(fp, r.msg, req.get_full_url()) + resp.code = r.status + resp.msg = r.reason + return resp + + +class HTTPHandler(AbstractHTTPHandler): + + def http_open(self, req): + return self.do_open(httplib.HTTPConnection, req) + + http_request = AbstractHTTPHandler.do_request_ + +if hasattr(httplib, 'HTTPS'): + class HTTPSHandler(AbstractHTTPHandler): + + def __init__(self, debuglevel=0, context=None): + AbstractHTTPHandler.__init__(self, debuglevel) + self._context = context + + def https_open(self, req): + return self.do_open(httplib.HTTPSConnection, req, + context=self._context) + + https_request = AbstractHTTPHandler.do_request_ + +class HTTPCookieProcessor(BaseHandler): + def __init__(self, cookiejar=None): + import cookielib + if cookiejar is None: + cookiejar = cookielib.CookieJar() + self.cookiejar = cookiejar + + def http_request(self, request): + self.cookiejar.add_cookie_header(request) + return request + + def http_response(self, request, response): + self.cookiejar.extract_cookies(response, request) + return response + + https_request = http_request + https_response = http_response + +class UnknownHandler(BaseHandler): + def unknown_open(self, req): + type = req.get_type() + raise URLError('unknown url type: %s' % type) + +def parse_keqv_list(l): + """Parse list of key=value strings where keys are not duplicated.""" + parsed = {} + for elt in l: + k, v = elt.split('=', 1) + if v[0] == '"' and v[-1] == '"': + v = v[1:-1] + parsed[k] = v + return parsed + +def parse_http_list(s): + """Parse lists as described by RFC 2068 Section 2. + + In particular, parse comma-separated lists where the elements of + the list may include quoted-strings. A quoted-string could + contain a comma. A non-quoted string could have quotes in the + middle. Neither commas nor quotes count if they are escaped. + Only double-quotes count, not single-quotes. + """ + res = [] + part = '' + + escape = quote = False + for cur in s: + if escape: + part += cur + escape = False + continue + if quote: + if cur == '\\': + escape = True + continue + elif cur == '"': + quote = False + part += cur + continue + + if cur == ',': + res.append(part) + part = '' + continue + + if cur == '"': + quote = True + + part += cur + + # append last part + if part: + res.append(part) + + return [part.strip() for part in res] + +def _safe_gethostbyname(host): + try: + return socket.gethostbyname(host) + except socket.gaierror: + return None + +class FileHandler(BaseHandler): + # Use local file or FTP depending on form of URL + def file_open(self, req): + url = req.get_selector() + if url[:2] == '//' and url[2:3] != '/' and (req.host and + req.host != 'localhost'): + req.type = 'ftp' + return self.parent.open(req) + else: + return self.open_local_file(req) + + # names for the localhost + names = None + def get_names(self): + if FileHandler.names is None: + try: + FileHandler.names = tuple( + socket.gethostbyname_ex('localhost')[2] + + socket.gethostbyname_ex(socket.gethostname())[2]) + except socket.gaierror: + FileHandler.names = (socket.gethostbyname('localhost'),) + return FileHandler.names + + # not entirely sure what the rules are here + def open_local_file(self, req): + import email.utils + import mimetypes + host = req.get_host() + filename = req.get_selector() + localfile = url2pathname(filename) + try: + stats = os.stat(localfile) + size = stats.st_size + modified = email.utils.formatdate(stats.st_mtime, usegmt=True) + mtype = mimetypes.guess_type(filename)[0] + headers = mimetools.Message(StringIO( + 'Content-type: %s\nContent-length: %d\nLast-modified: %s\n' % + (mtype or 'text/plain', size, modified))) + if host: + host, port = splitport(host) + if not host or \ + (not port and _safe_gethostbyname(host) in self.get_names()): + if host: + origurl = 'file://' + host + filename + else: + origurl = 'file://' + filename + return addinfourl(open(localfile, 'rb'), headers, origurl) + except OSError, msg: + # urllib2 users shouldn't expect OSErrors coming from urlopen() + raise URLError(msg) + raise URLError('file not on local host') + +class FTPHandler(BaseHandler): + def ftp_open(self, req): + import ftplib + import mimetypes + host = req.get_host() + if not host: + raise URLError('ftp error: no host given') + host, port = splitport(host) + if port is None: + port = ftplib.FTP_PORT + else: + port = int(port) + + # username/password handling + user, host = splituser(host) + if user: + user, passwd = splitpasswd(user) + else: + passwd = None + host = unquote(host) + user = user or '' + passwd = passwd or '' + + try: + host = socket.gethostbyname(host) + except socket.error, msg: + raise URLError(msg) + path, attrs = splitattr(req.get_selector()) + dirs = path.split('/') + dirs = map(unquote, dirs) + dirs, file = dirs[:-1], dirs[-1] + if dirs and not dirs[0]: + dirs = dirs[1:] + try: + fw = self.connect_ftp(user, passwd, host, port, dirs, req.timeout) + type = file and 'I' or 'D' + for attr in attrs: + attr, value = splitvalue(attr) + if attr.lower() == 'type' and \ + value in ('a', 'A', 'i', 'I', 'd', 'D'): + type = value.upper() + fp, retrlen = fw.retrfile(file, type) + headers = "" + mtype = mimetypes.guess_type(req.get_full_url())[0] + if mtype: + headers += "Content-type: %s\n" % mtype + if retrlen is not None and retrlen >= 0: + headers += "Content-length: %d\n" % retrlen + sf = StringIO(headers) + headers = mimetools.Message(sf) + return addinfourl(fp, headers, req.get_full_url()) + except ftplib.all_errors, msg: + raise URLError, ('ftp error: %s' % msg), sys.exc_info()[2] + + def connect_ftp(self, user, passwd, host, port, dirs, timeout): + fw = ftpwrapper(user, passwd, host, port, dirs, timeout, + persistent=False) +## fw.ftp.set_debuglevel(1) + return fw + +class CacheFTPHandler(FTPHandler): + # XXX would be nice to have pluggable cache strategies + # XXX this stuff is definitely not thread safe + def __init__(self): + self.cache = {} + self.timeout = {} + self.soonest = 0 + self.delay = 60 + self.max_conns = 16 + + def setTimeout(self, t): + self.delay = t + + def setMaxConns(self, m): + self.max_conns = m + + def connect_ftp(self, user, passwd, host, port, dirs, timeout): + key = user, host, port, '/'.join(dirs), timeout + if key in self.cache: + self.timeout[key] = time.time() + self.delay + else: + self.cache[key] = ftpwrapper(user, passwd, host, port, dirs, timeout) + self.timeout[key] = time.time() + self.delay + self.check_cache() + return self.cache[key] + + def check_cache(self): + # first check for old ones + t = time.time() + if self.soonest <= t: + for k, v in self.timeout.items(): + if v < t: + self.cache[k].close() + del self.cache[k] + del self.timeout[k] + self.soonest = min(self.timeout.values()) + + # then check the size + if len(self.cache) == self.max_conns: + for k, v in self.timeout.items(): + if v == self.soonest: + del self.cache[k] + del self.timeout[k] + break + self.soonest = min(self.timeout.values()) + + def clear_cache(self): + for conn in self.cache.values(): + conn.close() + self.cache.clear() + self.timeout.clear() diff --git a/urlparse.py b/urlparse.py new file mode 100644 index 0000000..4cd3d67 --- /dev/null +++ b/urlparse.py @@ -0,0 +1,428 @@ +"""Parse (absolute and relative) URLs. + +urlparse module is based upon the following RFC specifications. + +RFC 3986 (STD66): "Uniform Resource Identifiers" by T. Berners-Lee, R. Fielding +and L. Masinter, January 2005. + +RFC 2732 : "Format for Literal IPv6 Addresses in URL's by R.Hinden, B.Carpenter +and L.Masinter, December 1999. + +RFC 2396: "Uniform Resource Identifiers (URI)": Generic Syntax by T. +Berners-Lee, R. Fielding, and L. Masinter, August 1998. + +RFC 2368: "The mailto URL scheme", by P.Hoffman , L Masinter, J. Zwinski, July 1998. + +RFC 1808: "Relative Uniform Resource Locators", by R. Fielding, UC Irvine, June +1995. + +RFC 1738: "Uniform Resource Locators (URL)" by T. Berners-Lee, L. Masinter, M. +McCahill, December 1994 + +RFC 3986 is considered the current standard and any future changes to +urlparse module should conform with it. The urlparse module is +currently not entirely compliant with this RFC due to defacto +scenarios for parsing, and for backward compatibility purposes, some +parsing quirks from older RFCs are retained. The testcases in +test_urlparse.py provides a good indicator of parsing behavior. + +""" + +import re + +__all__ = ["urlparse", "urlunparse", "urljoin", "urldefrag", + "urlsplit", "urlunsplit", "parse_qs", "parse_qsl"] + +# A classification of schemes ('' means apply by default) +uses_relative = ['ftp', 'http', 'gopher', 'nntp', 'imap', + 'wais', 'file', 'https', 'shttp', 'mms', + 'prospero', 'rtsp', 'rtspu', '', 'sftp', + 'svn', 'svn+ssh'] +uses_netloc = ['ftp', 'http', 'gopher', 'nntp', 'telnet', + 'imap', 'wais', 'file', 'mms', 'https', 'shttp', + 'snews', 'prospero', 'rtsp', 'rtspu', 'rsync', '', + 'svn', 'svn+ssh', 'sftp','nfs','git', 'git+ssh'] +uses_params = ['ftp', 'hdl', 'prospero', 'http', 'imap', + 'https', 'shttp', 'rtsp', 'rtspu', 'sip', 'sips', + 'mms', '', 'sftp', 'tel'] + +# These are not actually used anymore, but should stay for backwards +# compatibility. (They are undocumented, but have a public-looking name.) +non_hierarchical = ['gopher', 'hdl', 'mailto', 'news', + 'telnet', 'wais', 'imap', 'snews', 'sip', 'sips'] +uses_query = ['http', 'wais', 'imap', 'https', 'shttp', 'mms', + 'gopher', 'rtsp', 'rtspu', 'sip', 'sips', ''] +uses_fragment = ['ftp', 'hdl', 'http', 'gopher', 'news', + 'nntp', 'wais', 'https', 'shttp', 'snews', + 'file', 'prospero', ''] + +# Characters valid in scheme names +scheme_chars = ('abcdefghijklmnopqrstuvwxyz' + 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' + '0123456789' + '+-.') + +MAX_CACHE_SIZE = 20 +_parse_cache = {} + +def clear_cache(): + """Clear the parse cache.""" + _parse_cache.clear() + + +class ResultMixin(object): + """Shared methods for the parsed result objects.""" + + @property + def username(self): + netloc = self.netloc + if "@" in netloc: + userinfo = netloc.rsplit("@", 1)[0] + if ":" in userinfo: + userinfo = userinfo.split(":", 1)[0] + return userinfo + return None + + @property + def password(self): + netloc = self.netloc + if "@" in netloc: + userinfo = netloc.rsplit("@", 1)[0] + if ":" in userinfo: + return userinfo.split(":", 1)[1] + return None + + @property + def hostname(self): + netloc = self.netloc.split('@')[-1] + if '[' in netloc and ']' in netloc: + return netloc.split(']')[0][1:].lower() + elif ':' in netloc: + return netloc.split(':')[0].lower() + elif netloc == '': + return None + else: + return netloc.lower() + + @property + def port(self): + netloc = self.netloc.split('@')[-1].split(']')[-1] + if ':' in netloc: + port = netloc.split(':')[1] + if port: + port = int(port, 10) + # verify legal port + if (0 <= port <= 65535): + return port + return None + +from collections import namedtuple + +class SplitResult(namedtuple('SplitResult', 'scheme netloc path query fragment'), ResultMixin): + + __slots__ = () + + def geturl(self): + return urlunsplit(self) + + +class ParseResult(namedtuple('ParseResult', 'scheme netloc path params query fragment'), ResultMixin): + + __slots__ = () + + def geturl(self): + return urlunparse(self) + + +def urlparse(url, scheme='', allow_fragments=True): + """Parse a URL into 6 components: + :///;?# + Return a 6-tuple: (scheme, netloc, path, params, query, fragment). + Note that we don't break the components up in smaller bits + (e.g. netloc is a single string) and we don't expand % escapes.""" + tuple = urlsplit(url, scheme, allow_fragments) + scheme, netloc, url, query, fragment = tuple + if scheme in uses_params and ';' in url: + url, params = _splitparams(url) + else: + params = '' + return ParseResult(scheme, netloc, url, params, query, fragment) + +def _splitparams(url): + if '/' in url: + i = url.find(';', url.rfind('/')) + if i < 0: + return url, '' + else: + i = url.find(';') + return url[:i], url[i+1:] + +def _splitnetloc(url, start=0): + delim = len(url) # position of end of domain part of url, default is end + for c in '/?#': # look for delimiters; the order is NOT important + wdelim = url.find(c, start) # find first of this delim + if wdelim >= 0: # if found + delim = min(delim, wdelim) # use earliest delim position + return url[start:delim], url[delim:] # return (domain, rest) + +def urlsplit(url, scheme='', allow_fragments=True): + """Parse a URL into 5 components: + :///?# + Return a 5-tuple: (scheme, netloc, path, query, fragment). + Note that we don't break the components up in smaller bits + (e.g. netloc is a single string) and we don't expand % escapes.""" + allow_fragments = bool(allow_fragments) + key = url, scheme, allow_fragments, type(url), type(scheme) + cached = _parse_cache.get(key, None) + if cached: + return cached + if len(_parse_cache) >= MAX_CACHE_SIZE: # avoid runaway growth + clear_cache() + netloc = query = fragment = '' + i = url.find(':') + if i > 0: + if url[:i] == 'http': # optimize the common case + scheme = url[:i].lower() + url = url[i+1:] + if url[:2] == '//': + netloc, url = _splitnetloc(url, 2) + if (('[' in netloc and ']' not in netloc) or + (']' in netloc and '[' not in netloc)): + raise ValueError("Invalid IPv6 URL") + if allow_fragments and '#' in url: + url, fragment = url.split('#', 1) + if '?' in url: + url, query = url.split('?', 1) + v = SplitResult(scheme, netloc, url, query, fragment) + _parse_cache[key] = v + return v + for c in url[:i]: + if c not in scheme_chars: + break + else: + # make sure "url" is not actually a port number (in which case + # "scheme" is really part of the path) + rest = url[i+1:] + if not rest or any(c not in '0123456789' for c in rest): + # not a port number + scheme, url = url[:i].lower(), rest + + if url[:2] == '//': + netloc, url = _splitnetloc(url, 2) + if (('[' in netloc and ']' not in netloc) or + (']' in netloc and '[' not in netloc)): + raise ValueError("Invalid IPv6 URL") + if allow_fragments and '#' in url: + url, fragment = url.split('#', 1) + if '?' in url: + url, query = url.split('?', 1) + v = SplitResult(scheme, netloc, url, query, fragment) + _parse_cache[key] = v + return v + +def urlunparse(data): + """Put a parsed URL back together again. This may result in a + slightly different, but equivalent URL, if the URL that was parsed + originally had redundant delimiters, e.g. a ? with an empty query + (the draft states that these are equivalent).""" + scheme, netloc, url, params, query, fragment = data + if params: + url = "%s;%s" % (url, params) + return urlunsplit((scheme, netloc, url, query, fragment)) + +def urlunsplit(data): + """Combine the elements of a tuple as returned by urlsplit() into a + complete URL as a string. The data argument can be any five-item iterable. + This may result in a slightly different, but equivalent URL, if the URL that + was parsed originally had unnecessary delimiters (for example, a ? with an + empty query; the RFC states that these are equivalent).""" + scheme, netloc, url, query, fragment = data + if netloc or (scheme and scheme in uses_netloc and url[:2] != '//'): + if url and url[:1] != '/': url = '/' + url + url = '//' + (netloc or '') + url + if scheme: + url = scheme + ':' + url + if query: + url = url + '?' + query + if fragment: + url = url + '#' + fragment + return url + +def urljoin(base, url, allow_fragments=True): + """Join a base URL and a possibly relative URL to form an absolute + interpretation of the latter.""" + if not base: + return url + if not url: + return base + bscheme, bnetloc, bpath, bparams, bquery, bfragment = \ + urlparse(base, '', allow_fragments) + scheme, netloc, path, params, query, fragment = \ + urlparse(url, bscheme, allow_fragments) + if scheme != bscheme or scheme not in uses_relative: + return url + if scheme in uses_netloc: + if netloc: + return urlunparse((scheme, netloc, path, + params, query, fragment)) + netloc = bnetloc + if path[:1] == '/': + return urlunparse((scheme, netloc, path, + params, query, fragment)) + if not path and not params: + path = bpath + params = bparams + if not query: + query = bquery + return urlunparse((scheme, netloc, path, + params, query, fragment)) + segments = bpath.split('/')[:-1] + path.split('/') + # XXX The stuff below is bogus in various ways... + if segments[-1] == '.': + segments[-1] = '' + while '.' in segments: + segments.remove('.') + while 1: + i = 1 + n = len(segments) - 1 + while i < n: + if (segments[i] == '..' + and segments[i-1] not in ('', '..')): + del segments[i-1:i+1] + break + i = i+1 + else: + break + if segments == ['', '..']: + segments[-1] = '' + elif len(segments) >= 2 and segments[-1] == '..': + segments[-2:] = [''] + return urlunparse((scheme, netloc, '/'.join(segments), + params, query, fragment)) + +def urldefrag(url): + """Removes any existing fragment from URL. + + Returns a tuple of the defragmented URL and the fragment. If + the URL contained no fragments, the second element is the + empty string. + """ + if '#' in url: + s, n, p, a, q, frag = urlparse(url) + defrag = urlunparse((s, n, p, a, q, '')) + return defrag, frag + else: + return url, '' + +try: + unicode +except NameError: + def _is_unicode(x): + return 0 +else: + def _is_unicode(x): + return isinstance(x, unicode) + +# unquote method for parse_qs and parse_qsl +# Cannot use directly from urllib as it would create a circular reference +# because urllib uses urlparse methods (urljoin). If you update this function, +# update it also in urllib. This code duplication does not existin in Python3. + +_hexdig = '0123456789ABCDEFabcdef' +_hextochr = dict((a+b, chr(int(a+b,16))) + for a in _hexdig for b in _hexdig) +_asciire = re.compile('([\x00-\x7f]+)') + +def unquote(s): + """unquote('abc%20def') -> 'abc def'.""" + if _is_unicode(s): + if '%' not in s: + return s + bits = _asciire.split(s) + res = [bits[0]] + append = res.append + for i in range(1, len(bits), 2): + append(unquote(str(bits[i])).decode('latin1')) + append(bits[i + 1]) + return ''.join(res) + + bits = s.split('%') + # fastpath + if len(bits) == 1: + return s + res = [bits[0]] + append = res.append + for item in bits[1:]: + try: + append(_hextochr[item[:2]]) + append(item[2:]) + except KeyError: + append('%') + append(item) + return ''.join(res) + +def parse_qs(qs, keep_blank_values=0, strict_parsing=0): + """Parse a query given as a string argument. + + Arguments: + + qs: percent-encoded query string to be parsed + + keep_blank_values: flag indicating whether blank values in + percent-encoded queries should be treated as blank strings. + A true value indicates that blanks should be retained as + blank strings. The default false value indicates that + blank values are to be ignored and treated as if they were + not included. + + strict_parsing: flag indicating what to do with parsing errors. + If false (the default), errors are silently ignored. + If true, errors raise a ValueError exception. + """ + dict = {} + for name, value in parse_qsl(qs, keep_blank_values, strict_parsing): + if name in dict: + dict[name].append(value) + else: + dict[name] = [value] + return dict + +def parse_qsl(qs, keep_blank_values=0, strict_parsing=0): + """Parse a query given as a string argument. + + Arguments: + + qs: percent-encoded query string to be parsed + + keep_blank_values: flag indicating whether blank values in + percent-encoded queries should be treated as blank strings. A + true value indicates that blanks should be retained as blank + strings. The default false value indicates that blank values + are to be ignored and treated as if they were not included. + + strict_parsing: flag indicating what to do with parsing errors. If + false (the default), errors are silently ignored. If true, + errors raise a ValueError exception. + + Returns a list, as G-d intended. + """ + pairs = [s2 for s1 in qs.split('&') for s2 in s1.split(';')] + r = [] + for name_value in pairs: + if not name_value and not strict_parsing: + continue + nv = name_value.split('=', 1) + if len(nv) != 2: + if strict_parsing: + raise ValueError, "bad query field: %r" % (name_value,) + # Handle case of a control-name with no equal sign + if keep_blank_values: + nv.append('') + else: + continue + if len(nv[1]) or keep_blank_values: + name = unquote(nv[0].replace('+', ' ')) + value = unquote(nv[1].replace('+', ' ')) + r.append((name, value)) + + return r diff --git a/warnings.py b/warnings.py new file mode 100644 index 0000000..b0d53aa --- /dev/null +++ b/warnings.py @@ -0,0 +1,422 @@ +"""Python part of the warnings subsystem.""" + +# Note: function level imports should *not* be used +# in this module as it may cause import lock deadlock. +# See bug 683658. +import linecache +import sys +import types + +__all__ = ["warn", "warn_explicit", "showwarning", + "formatwarning", "filterwarnings", "simplefilter", + "resetwarnings", "catch_warnings"] + + +def warnpy3k(message, category=None, stacklevel=1): + """Issue a deprecation warning for Python 3.x related changes. + + Warnings are omitted unless Python is started with the -3 option. + """ + if sys.py3kwarning: + if category is None: + category = DeprecationWarning + warn(message, category, stacklevel+1) + +def _show_warning(message, category, filename, lineno, file=None, line=None): + """Hook to write a warning to a file; replace if you like.""" + if file is None: + file = sys.stderr + if file is None: + # sys.stderr is None - warnings get lost + return + try: + file.write(formatwarning(message, category, filename, lineno, line)) + except (IOError, UnicodeError): + pass # the file (probably stderr) is invalid - this warning gets lost. +# Keep a working version around in case the deprecation of the old API is +# triggered. +showwarning = _show_warning + +def formatwarning(message, category, filename, lineno, line=None): + """Function to format a warning the standard way.""" + try: + unicodetype = unicode + except NameError: + unicodetype = () + try: + message = str(message) + except UnicodeEncodeError: + pass + s = "%s: %s: %s\n" % (lineno, category.__name__, message) + line = linecache.getline(filename, lineno) if line is None else line + if line: + line = line.strip() + if isinstance(s, unicodetype) and isinstance(line, str): + line = unicode(line, 'latin1') + s += " %s\n" % line + if isinstance(s, unicodetype) and isinstance(filename, str): + enc = sys.getfilesystemencoding() + if enc: + try: + filename = unicode(filename, enc) + except UnicodeDecodeError: + pass + s = "%s:%s" % (filename, s) + return s + +def filterwarnings(action, message="", category=Warning, module="", lineno=0, + append=0): + """Insert an entry into the list of warnings filters (at the front). + + 'action' -- one of "error", "ignore", "always", "default", "module", + or "once" + 'message' -- a regex that the warning message must match + 'category' -- a class that the warning must be a subclass of + 'module' -- a regex that the module name must match + 'lineno' -- an integer line number, 0 matches all warnings + 'append' -- if true, append to the list of filters + """ + import re + assert action in ("error", "ignore", "always", "default", "module", + "once"), "invalid action: %r" % (action,) + assert isinstance(message, basestring), "message must be a string" + assert isinstance(category, (type, types.ClassType)), \ + "category must be a class" + assert issubclass(category, Warning), "category must be a Warning subclass" + assert isinstance(module, basestring), "module must be a string" + assert isinstance(lineno, int) and lineno >= 0, \ + "lineno must be an int >= 0" + item = (action, re.compile(message, re.I), category, + re.compile(module), lineno) + if append: + filters.append(item) + else: + filters.insert(0, item) + +def simplefilter(action, category=Warning, lineno=0, append=0): + """Insert a simple entry into the list of warnings filters (at the front). + + A simple filter matches all modules and messages. + 'action' -- one of "error", "ignore", "always", "default", "module", + or "once" + 'category' -- a class that the warning must be a subclass of + 'lineno' -- an integer line number, 0 matches all warnings + 'append' -- if true, append to the list of filters + """ + assert action in ("error", "ignore", "always", "default", "module", + "once"), "invalid action: %r" % (action,) + assert isinstance(lineno, int) and lineno >= 0, \ + "lineno must be an int >= 0" + item = (action, None, category, None, lineno) + if append: + filters.append(item) + else: + filters.insert(0, item) + +def resetwarnings(): + """Clear the list of warning filters, so that no filters are active.""" + filters[:] = [] + +class _OptionError(Exception): + """Exception used by option processing helpers.""" + pass + +# Helper to process -W options passed via sys.warnoptions +def _processoptions(args): + for arg in args: + try: + _setoption(arg) + except _OptionError, msg: + print >>sys.stderr, "Invalid -W option ignored:", msg + +# Helper for _processoptions() +def _setoption(arg): + import re + parts = arg.split(':') + if len(parts) > 5: + raise _OptionError("too many fields (max 5): %r" % (arg,)) + while len(parts) < 5: + parts.append('') + action, message, category, module, lineno = [s.strip() + for s in parts] + action = _getaction(action) + message = re.escape(message) + category = _getcategory(category) + module = re.escape(module) + if module: + module = module + '$' + if lineno: + try: + lineno = int(lineno) + if lineno < 0: + raise ValueError + except (ValueError, OverflowError): + raise _OptionError("invalid lineno %r" % (lineno,)) + else: + lineno = 0 + filterwarnings(action, message, category, module, lineno) + +# Helper for _setoption() +def _getaction(action): + if not action: + return "default" + if action == "all": return "always" # Alias + for a in ('default', 'always', 'ignore', 'module', 'once', 'error'): + if a.startswith(action): + return a + raise _OptionError("invalid action: %r" % (action,)) + +# Helper for _setoption() +def _getcategory(category): + import re + if not category: + return Warning + if re.match("^[a-zA-Z0-9_]+$", category): + try: + cat = eval(category) + except NameError: + raise _OptionError("unknown warning category: %r" % (category,)) + else: + i = category.rfind(".") + module = category[:i] + klass = category[i+1:] + try: + m = __import__(module, None, None, [klass]) + except ImportError: + raise _OptionError("invalid module name: %r" % (module,)) + try: + cat = getattr(m, klass) + except AttributeError: + raise _OptionError("unknown warning category: %r" % (category,)) + if not issubclass(cat, Warning): + raise _OptionError("invalid warning category: %r" % (category,)) + return cat + + +# Code typically replaced by _warnings +def warn(message, category=None, stacklevel=1): + """Issue a warning, or maybe ignore it or raise an exception.""" + # Check if message is already a Warning object + if isinstance(message, Warning): + category = message.__class__ + # Check category argument + if category is None: + category = UserWarning + assert issubclass(category, Warning) + # Get context information + try: + caller = sys._getframe(stacklevel) + except ValueError: + globals = sys.__dict__ + lineno = 1 + else: + globals = caller.f_globals + lineno = caller.f_lineno + if '__name__' in globals: + module = globals['__name__'] + else: + module = "" + filename = globals.get('__file__') + if filename: + fnl = filename.lower() + if fnl.endswith((".pyc", ".pyo")): + filename = filename[:-1] + else: + if module == "__main__": + try: + filename = sys.argv[0] + except AttributeError: + # embedded interpreters don't have sys.argv, see bug #839151 + filename = '__main__' + if not filename: + filename = module + registry = globals.setdefault("__warningregistry__", {}) + warn_explicit(message, category, filename, lineno, module, registry, + globals) + +def warn_explicit(message, category, filename, lineno, + module=None, registry=None, module_globals=None): + lineno = int(lineno) + if module is None: + module = filename or "" + if module[-3:].lower() == ".py": + module = module[:-3] # XXX What about leading pathname? + if registry is None: + registry = {} + if isinstance(message, Warning): + text = str(message) + category = message.__class__ + else: + text = message + message = category(message) + key = (text, category, lineno) + # Quick test for common case + if registry.get(key): + return + # Search the filters + for item in filters: + action, msg, cat, mod, ln = item + if ((msg is None or msg.match(text)) and + issubclass(category, cat) and + (mod is None or mod.match(module)) and + (ln == 0 or lineno == ln)): + break + else: + action = defaultaction + # Early exit actions + if action == "ignore": + registry[key] = 1 + return + + # Prime the linecache for formatting, in case the + # "file" is actually in a zipfile or something. + linecache.getlines(filename, module_globals) + + if action == "error": + raise message + # Other actions + if action == "once": + registry[key] = 1 + oncekey = (text, category) + if onceregistry.get(oncekey): + return + onceregistry[oncekey] = 1 + elif action == "always": + pass + elif action == "module": + registry[key] = 1 + altkey = (text, category, 0) + if registry.get(altkey): + return + registry[altkey] = 1 + elif action == "default": + registry[key] = 1 + else: + # Unrecognized actions are errors + raise RuntimeError( + "Unrecognized action (%r) in warnings.filters:\n %s" % + (action, item)) + # Print message and context + showwarning(message, category, filename, lineno) + + +class WarningMessage(object): + + """Holds the result of a single showwarning() call.""" + + _WARNING_DETAILS = ("message", "category", "filename", "lineno", "file", + "line") + + def __init__(self, message, category, filename, lineno, file=None, + line=None): + local_values = locals() + for attr in self._WARNING_DETAILS: + setattr(self, attr, local_values[attr]) + self._category_name = category.__name__ if category else None + + def __str__(self): + return ("{message : %r, category : %r, filename : %r, lineno : %s, " + "line : %r}" % (self.message, self._category_name, + self.filename, self.lineno, self.line)) + + +class catch_warnings(object): + + """A context manager that copies and restores the warnings filter upon + exiting the context. + + The 'record' argument specifies whether warnings should be captured by a + custom implementation of warnings.showwarning() and be appended to a list + returned by the context manager. Otherwise None is returned by the context + manager. The objects appended to the list are arguments whose attributes + mirror the arguments to showwarning(). + + The 'module' argument is to specify an alternative module to the module + named 'warnings' and imported under that name. This argument is only useful + when testing the warnings module itself. + + """ + + def __init__(self, record=False, module=None): + """Specify whether to record warnings and if an alternative module + should be used other than sys.modules['warnings']. + + For compatibility with Python 3.0, please consider all arguments to be + keyword-only. + + """ + self._record = record + self._module = sys.modules['warnings'] if module is None else module + self._entered = False + + def __repr__(self): + args = [] + if self._record: + args.append("record=True") + if self._module is not sys.modules['warnings']: + args.append("module=%r" % self._module) + name = type(self).__name__ + return "%s(%s)" % (name, ", ".join(args)) + + def __enter__(self): + if self._entered: + raise RuntimeError("Cannot enter %r twice" % self) + self._entered = True + self._filters = self._module.filters + self._module.filters = self._filters[:] + self._showwarning = self._module.showwarning + if self._record: + log = [] + def showwarning(*args, **kwargs): + log.append(WarningMessage(*args, **kwargs)) + self._module.showwarning = showwarning + return log + else: + return None + + def __exit__(self, *exc_info): + if not self._entered: + raise RuntimeError("Cannot exit %r without entering first" % self) + self._module.filters = self._filters + self._module.showwarning = self._showwarning + + +# filters contains a sequence of filter 5-tuples +# The components of the 5-tuple are: +# - an action: error, ignore, always, default, module, or once +# - a compiled regex that must match the warning message +# - a class representing the warning category +# - a compiled regex that must match the module that is being warned +# - a line number for the line being warning, or 0 to mean any line +# If either if the compiled regexs are None, match anything. +_warnings_defaults = False +try: + from _warnings import (filters, default_action, once_registry, + warn, warn_explicit) + defaultaction = default_action + onceregistry = once_registry + _warnings_defaults = True +except ImportError: + filters = [] + defaultaction = "default" + onceregistry = {} + + +# Module initialization +_processoptions(sys.warnoptions) +if not _warnings_defaults: + silence = [ImportWarning, PendingDeprecationWarning] + # Don't silence DeprecationWarning if -3 or -Q was used. + if not sys.py3kwarning and not sys.flags.division_warning: + silence.append(DeprecationWarning) + for cls in silence: + simplefilter("ignore", category=cls) + bytes_warning = sys.flags.bytes_warning + if bytes_warning > 1: + bytes_action = "error" + elif bytes_warning: + bytes_action = "default" + else: + bytes_action = "ignore" + simplefilter(bytes_action, category=BytesWarning, append=1) +del _warnings_defaults diff --git a/weakref.py b/weakref.py new file mode 100644 index 0000000..ca37f87 --- /dev/null +++ b/weakref.py @@ -0,0 +1,458 @@ +"""Weak reference support for Python. + +This module is an implementation of PEP 205: + +http://www.python.org/dev/peps/pep-0205/ +""" + +# Naming convention: Variables named "wr" are weak reference objects; +# they are called this instead of "ref" to avoid name collisions with +# the module-global ref() function imported from _weakref. + +import UserDict + +from _weakref import ( + getweakrefcount, + getweakrefs, + ref, + proxy, + CallableProxyType, + ProxyType, + ReferenceType) + +from _weakrefset import WeakSet, _IterationGuard + +from exceptions import ReferenceError + + +ProxyTypes = (ProxyType, CallableProxyType) + +__all__ = ["ref", "proxy", "getweakrefcount", "getweakrefs", + "WeakKeyDictionary", "ReferenceError", "ReferenceType", "ProxyType", + "CallableProxyType", "ProxyTypes", "WeakValueDictionary", 'WeakSet'] + + +class WeakValueDictionary(UserDict.UserDict): + """Mapping class that references values weakly. + + Entries in the dictionary will be discarded when no strong + reference to the value exists anymore + """ + # We inherit the constructor without worrying about the input + # dictionary; since it uses our .update() method, we get the right + # checks (if the other dictionary is a WeakValueDictionary, + # objects are unwrapped on the way out, and we always wrap on the + # way in). + + def __init__(*args, **kw): + if not args: + raise TypeError("descriptor '__init__' of 'WeakValueDictionary' " + "object needs an argument") + self = args[0] + args = args[1:] + if len(args) > 1: + raise TypeError('expected at most 1 arguments, got %d' % len(args)) + def remove(wr, selfref=ref(self)): + self = selfref() + if self is not None: + if self._iterating: + self._pending_removals.append(wr.key) + else: + del self.data[wr.key] + self._remove = remove + # A list of keys to be removed + self._pending_removals = [] + self._iterating = set() + UserDict.UserDict.__init__(self, *args, **kw) + + def _commit_removals(self): + l = self._pending_removals + d = self.data + # We shouldn't encounter any KeyError, because this method should + # always be called *before* mutating the dict. + while l: + del d[l.pop()] + + def __getitem__(self, key): + o = self.data[key]() + if o is None: + raise KeyError, key + else: + return o + + def __delitem__(self, key): + if self._pending_removals: + self._commit_removals() + del self.data[key] + + def __contains__(self, key): + try: + o = self.data[key]() + except KeyError: + return False + return o is not None + + def has_key(self, key): + try: + o = self.data[key]() + except KeyError: + return False + return o is not None + + def __repr__(self): + return "" % id(self) + + def __setitem__(self, key, value): + if self._pending_removals: + self._commit_removals() + self.data[key] = KeyedRef(value, self._remove, key) + + def clear(self): + if self._pending_removals: + self._commit_removals() + self.data.clear() + + def copy(self): + new = WeakValueDictionary() + for key, wr in self.data.items(): + o = wr() + if o is not None: + new[key] = o + return new + + __copy__ = copy + + def __deepcopy__(self, memo): + from copy import deepcopy + new = self.__class__() + for key, wr in self.data.items(): + o = wr() + if o is not None: + new[deepcopy(key, memo)] = o + return new + + def get(self, key, default=None): + try: + wr = self.data[key] + except KeyError: + return default + else: + o = wr() + if o is None: + # This should only happen + return default + else: + return o + + def items(self): + L = [] + for key, wr in self.data.items(): + o = wr() + if o is not None: + L.append((key, o)) + return L + + def iteritems(self): + with _IterationGuard(self): + for wr in self.data.itervalues(): + value = wr() + if value is not None: + yield wr.key, value + + def iterkeys(self): + with _IterationGuard(self): + for k in self.data.iterkeys(): + yield k + + __iter__ = iterkeys + + def itervaluerefs(self): + """Return an iterator that yields the weak references to the values. + + The references are not guaranteed to be 'live' at the time + they are used, so the result of calling the references needs + to be checked before being used. This can be used to avoid + creating references that will cause the garbage collector to + keep the values around longer than needed. + + """ + with _IterationGuard(self): + for wr in self.data.itervalues(): + yield wr + + def itervalues(self): + with _IterationGuard(self): + for wr in self.data.itervalues(): + obj = wr() + if obj is not None: + yield obj + + def popitem(self): + if self._pending_removals: + self._commit_removals() + while 1: + key, wr = self.data.popitem() + o = wr() + if o is not None: + return key, o + + def pop(self, key, *args): + if self._pending_removals: + self._commit_removals() + try: + o = self.data.pop(key)() + except KeyError: + if args: + return args[0] + raise + if o is None: + raise KeyError, key + else: + return o + + def setdefault(self, key, default=None): + try: + wr = self.data[key] + except KeyError: + if self._pending_removals: + self._commit_removals() + self.data[key] = KeyedRef(default, self._remove, key) + return default + else: + return wr() + + def update(*args, **kwargs): + if not args: + raise TypeError("descriptor 'update' of 'WeakValueDictionary' " + "object needs an argument") + self = args[0] + args = args[1:] + if len(args) > 1: + raise TypeError('expected at most 1 arguments, got %d' % len(args)) + dict = args[0] if args else None + if self._pending_removals: + self._commit_removals() + d = self.data + if dict is not None: + if not hasattr(dict, "items"): + dict = type({})(dict) + for key, o in dict.items(): + d[key] = KeyedRef(o, self._remove, key) + if len(kwargs): + self.update(kwargs) + + def valuerefs(self): + """Return a list of weak references to the values. + + The references are not guaranteed to be 'live' at the time + they are used, so the result of calling the references needs + to be checked before being used. This can be used to avoid + creating references that will cause the garbage collector to + keep the values around longer than needed. + + """ + return self.data.values() + + def values(self): + L = [] + for wr in self.data.values(): + o = wr() + if o is not None: + L.append(o) + return L + + +class KeyedRef(ref): + """Specialized reference that includes a key corresponding to the value. + + This is used in the WeakValueDictionary to avoid having to create + a function object for each key stored in the mapping. A shared + callback object can use the 'key' attribute of a KeyedRef instead + of getting a reference to the key from an enclosing scope. + + """ + + __slots__ = "key", + + def __new__(type, ob, callback, key): + self = ref.__new__(type, ob, callback) + self.key = key + return self + + def __init__(self, ob, callback, key): + super(KeyedRef, self).__init__(ob, callback) + + +class WeakKeyDictionary(UserDict.UserDict): + """ Mapping class that references keys weakly. + + Entries in the dictionary will be discarded when there is no + longer a strong reference to the key. This can be used to + associate additional data with an object owned by other parts of + an application without adding attributes to those objects. This + can be especially useful with objects that override attribute + accesses. + """ + + def __init__(self, dict=None): + self.data = {} + def remove(k, selfref=ref(self)): + self = selfref() + if self is not None: + if self._iterating: + self._pending_removals.append(k) + else: + del self.data[k] + self._remove = remove + # A list of dead weakrefs (keys to be removed) + self._pending_removals = [] + self._iterating = set() + if dict is not None: + self.update(dict) + + def _commit_removals(self): + # NOTE: We don't need to call this method before mutating the dict, + # because a dead weakref never compares equal to a live weakref, + # even if they happened to refer to equal objects. + # However, it means keys may already have been removed. + l = self._pending_removals + d = self.data + while l: + try: + del d[l.pop()] + except KeyError: + pass + + def __delitem__(self, key): + del self.data[ref(key)] + + def __getitem__(self, key): + return self.data[ref(key)] + + def __repr__(self): + return "" % id(self) + + def __setitem__(self, key, value): + self.data[ref(key, self._remove)] = value + + def copy(self): + new = WeakKeyDictionary() + for key, value in self.data.items(): + o = key() + if o is not None: + new[o] = value + return new + + __copy__ = copy + + def __deepcopy__(self, memo): + from copy import deepcopy + new = self.__class__() + for key, value in self.data.items(): + o = key() + if o is not None: + new[o] = deepcopy(value, memo) + return new + + def get(self, key, default=None): + return self.data.get(ref(key),default) + + def has_key(self, key): + try: + wr = ref(key) + except TypeError: + return 0 + return wr in self.data + + def __contains__(self, key): + try: + wr = ref(key) + except TypeError: + return 0 + return wr in self.data + + def items(self): + L = [] + for key, value in self.data.items(): + o = key() + if o is not None: + L.append((o, value)) + return L + + def iteritems(self): + with _IterationGuard(self): + for wr, value in self.data.iteritems(): + key = wr() + if key is not None: + yield key, value + + def iterkeyrefs(self): + """Return an iterator that yields the weak references to the keys. + + The references are not guaranteed to be 'live' at the time + they are used, so the result of calling the references needs + to be checked before being used. This can be used to avoid + creating references that will cause the garbage collector to + keep the keys around longer than needed. + + """ + with _IterationGuard(self): + for wr in self.data.iterkeys(): + yield wr + + def iterkeys(self): + with _IterationGuard(self): + for wr in self.data.iterkeys(): + obj = wr() + if obj is not None: + yield obj + + __iter__ = iterkeys + + def itervalues(self): + with _IterationGuard(self): + for value in self.data.itervalues(): + yield value + + def keyrefs(self): + """Return a list of weak references to the keys. + + The references are not guaranteed to be 'live' at the time + they are used, so the result of calling the references needs + to be checked before being used. This can be used to avoid + creating references that will cause the garbage collector to + keep the keys around longer than needed. + + """ + return self.data.keys() + + def keys(self): + L = [] + for wr in self.data.keys(): + o = wr() + if o is not None: + L.append(o) + return L + + def popitem(self): + while 1: + key, value = self.data.popitem() + o = key() + if o is not None: + return o, value + + def pop(self, key, *args): + return self.data.pop(ref(key), *args) + + def setdefault(self, key, default=None): + return self.data.setdefault(ref(key, self._remove),default) + + def update(self, dict=None, **kwargs): + d = self.data + if dict is not None: + if not hasattr(dict, "items"): + dict = type({})(dict) + for key, value in dict.items(): + d[ref(key, self._remove)] = value + if len(kwargs): + self.update(kwargs) diff --git a/xml2py.py b/xml2py.py new file mode 100644 index 0000000..e108f0a --- /dev/null +++ b/xml2py.py @@ -0,0 +1,121 @@ +##################################################################################### +# +# Copyright (c) Harry Pierson. All rights reserved. +# +# This source code is subject to terms and conditions of the Microsoft Public License. +# A copy of the license can be found at http://opensource.org/licenses/ms-pl.html +# By using this source code in any fashion, you are agreeing to be bound +# by the terms of the Microsoft Public License. +# +# You must not remove this notice, or any other, from this software. +# +##################################################################################### + +import ipypulldom +from System.Xml import XmlNodeType + +class _type_factory(object): + class _type_node(object): + def __init__(self, node): + ty = type(node) + self.name = ty.__name__ + self.namespace = ty.xmlns + + def __init__(self): + self.types = {} + + def find_type(self, node, parent): + def create_type(node, parent): + return type(node.name, (parent,), {'xmlns':node.namespace}) + + if parent not in self.types: + self.types[parent] = {} + + tp = self.types[parent] + if node.name not in tp: + tp[node.name] = [create_type(node, parent)] + + tpn = tp[node.name] + + for t in tpn: + if t.xmlns == node.namespace: + return t + + #if there's no matching namespace type, create one and add it to the list + new_type = create_type(node, parent) + tpn.append(new_type) + return new_type + + def __call__(self, node, parent=object): + if isinstance(node, ipypulldom.XmlNode): + return self.find_type(node, parent) + return self.find_type(self._type_node(node), parent) + + +xtype = _type_factory() + + +def xml2py(nodelist): + + def children(nodelist): + while True: + child = xml2py(nodelist) + if child is None: + break + yield child + + def set_attribute(parent, child): + name = type(child).__name__ + if not hasattr(parent, name): + setattr(parent, name, child) + else: + val = getattr(parent, name) + if isinstance(val, list): + val.append(child) + else: + setattr(parent, name, [val, child]) + + node = nodelist.next() + if node.nodeType == XmlNodeType.EndElement: + return None + + elif node.nodeType == XmlNodeType.Text or node.nodeType == XmlNodeType.CDATA: + return node.value + + elif node.nodeType == XmlNodeType.Element: + + #create a new object type named for the element name + cur = xtype(node)() + cur._nodetype = XmlNodeType.Element + + #collect all the attributes and children in lists + attributes = [xtype(attr, str)(attr.value) for attr in node.attributes] + children = [child for child in children(nodelist)] + + if len(children) == 1 and isinstance(children[0], str): + #fold up elements with a single text node + cur = xtype(cur, str)(children[0]) + cur._nodetype = XmlNodeType.Element + else: + #otherwise, add child elements as properties on the current node + for child in children: + set_attribute(cur, child) + + for attr in attributes: + attr._nodetype = XmlNodeType.Attribute + set_attribute(cur, attr) + + return cur + + +def parse(xml): + return xml2py(ipypulldom.parse(xml)) + +def parseString(xml): + return xml2py(ipypulldom.parseString(xml)) + + +if __name__ == '__main__': + rss = parse('http://feeds.feedburner.com/Devhawk') + for item in rss.channel.item: + print item.title